From 64c24ab937fdbaeb15df2105c7957ffe614a6d03 Mon Sep 17 00:00:00 2001 From: shazbert Date: Sun, 14 Jul 2024 16:13:38 +1000 Subject: [PATCH 001/138] gateio: Add multi asset websocket support WIP. --- exchanges/gateio/gateio_test.go | 28 +-- exchanges/gateio/gateio_websocket.go | 52 ++--- exchanges/gateio/gateio_wrapper.go | 133 +++++++++--- .../gateio/gateio_ws_delivery_futures.go | 129 ++++-------- exchanges/gateio/gateio_ws_futures.go | 147 +++++--------- exchanges/gateio/gateio_ws_option.go | 48 ++--- exchanges/stream/stream_types.go | 9 + exchanges/stream/websocket.go | 192 ++++++++++++------ exchanges/stream/websocket_connection.go | 36 +++- exchanges/stream/websocket_test.go | 3 - exchanges/stream/websocket_types.go | 2 + 11 files changed, 427 insertions(+), 352 deletions(-) diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index e3d075be9ee..ea9bd9d82aa 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -2548,7 +2548,7 @@ const wsTickerPushDataJSON = `{"time": 1606291803, "channel": "spot.tickers", "e func TestWsTickerPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsTickerPushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsTickerPushDataJSON)); err != nil { t.Errorf("%s websocket ticker push data error: %v", g.Name, err) } } @@ -2557,7 +2557,7 @@ const wsTradePushDataJSON = `{ "time": 1606292218, "channel": "spot.trades", "ev func TestWsTradePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsTradePushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsTradePushDataJSON)); err != nil { t.Errorf("%s websocket trade push data error: %v", g.Name, err) } } @@ -2566,7 +2566,7 @@ const wsCandlestickPushDataJSON = `{"time": 1606292600, "channel": "spot.candles func TestWsCandlestickPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsCandlestickPushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsCandlestickPushDataJSON)); err != nil { t.Errorf("%s websocket candlestick push data error: %v", g.Name, err) } } @@ -2575,7 +2575,7 @@ const wsOrderbookTickerJSON = `{"time": 1606293275, "channel": "spot.book_ticker func TestWsOrderbookTickerPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsOrderbookTickerJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsOrderbookTickerJSON)); err != nil { t.Errorf("%s websocket orderbook push data error: %v", g.Name, err) } } @@ -2587,11 +2587,11 @@ const ( func TestWsOrderbookSnapshotPushData(t *testing.T) { t.Parallel() - err := g.wsHandleData([]byte(wsOrderbookSnapshotPushDataJSON)) + err := g.WsHandleSpotData(context.Background(), []byte(wsOrderbookSnapshotPushDataJSON)) if err != nil { t.Errorf("%s websocket orderbook snapshot push data error: %v", g.Name, err) } - if err = g.wsHandleData([]byte(wsOrderbookUpdatePushDataJSON)); err != nil { + if err = g.WsHandleSpotData(context.Background(), []byte(wsOrderbookUpdatePushDataJSON)); err != nil { t.Errorf("%s websocket orderbook update push data error: %v", g.Name, err) } } @@ -2600,7 +2600,7 @@ const wsSpotOrderPushDataJSON = `{"time": 1605175506, "channel": "spot.orders", func TestWsPushOrders(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsSpotOrderPushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsSpotOrderPushDataJSON)); err != nil { t.Errorf("%s websocket orders push data error: %v", g.Name, err) } } @@ -2609,7 +2609,7 @@ const wsUserTradePushDataJSON = `{"time": 1605176741, "channel": "spot.usertrade func TestWsUserTradesPushDataJSON(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsUserTradePushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsUserTradePushDataJSON)); err != nil { t.Errorf("%s websocket users trade push data error: %v", g.Name, err) } } @@ -2618,7 +2618,7 @@ const wsBalancesPushDataJSON = `{"time": 1605248616, "channel": "spot.balances", func TestBalancesPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsBalancesPushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsBalancesPushDataJSON)); err != nil { t.Errorf("%s websocket balances push data error: %v", g.Name, err) } } @@ -2627,7 +2627,7 @@ const wsMarginBalancePushDataJSON = `{"time": 1605248616, "channel": "spot.fundi func TestMarginBalancePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsMarginBalancePushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsMarginBalancePushDataJSON)); err != nil { t.Errorf("%s websocket margin balance push data error: %v", g.Name, err) } } @@ -2636,7 +2636,7 @@ const wsCrossMarginBalancePushDataJSON = `{"time": 1605248616,"channel": "spot.c func TestCrossMarginBalancePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsCrossMarginBalancePushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsCrossMarginBalancePushDataJSON)); err != nil { t.Errorf("%s websocket cross margin balance push data error: %v", g.Name, err) } } @@ -2645,7 +2645,7 @@ const wsCrossMarginBalanceLoan = `{ "time":1658289372, "channel":"spot.cross_loa func TestCrossMarginBalanceLoan(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsCrossMarginBalanceLoan)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsCrossMarginBalanceLoan)); err != nil { t.Errorf("%s websocket cross margin loan push data error: %v", g.Name, err) } } @@ -2963,9 +2963,9 @@ func TestFuturesCandlestickPushData(t *testing.T) { } } -func TestGenerateDefaultSubscriptions(t *testing.T) { +func TestGenerateDefaultSubscriptionsSpot(t *testing.T) { t.Parallel() - if _, err := g.GenerateDefaultSubscriptions(); err != nil { + if _, err := g.GenerateDefaultSubscriptionsSpot(); err != nil { t.Error(err) } } diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index a5b7b43f71c..cefd40931b4 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -58,7 +58,7 @@ var defaultSubscriptions = []string{ var fetchedCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsConnect initiates a websocket connection -func (g *Gateio) WsConnect() error { +func (g *Gateio) WsConnectSpot(ctx context.Context, conn stream.Connection) error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { return stream.ErrWebsocketNotEnabled } @@ -66,7 +66,7 @@ func (g *Gateio) WsConnect() error { if err != nil { return err } - err = g.Websocket.Conn.Dial(&websocket.Dialer{}, http.Header{}) + err = conn.DialContext(ctx, &websocket.Dialer{}, http.Header{}) if err != nil { return err } @@ -74,14 +74,12 @@ func (g *Gateio) WsConnect() error { if err != nil { return err } - g.Websocket.Conn.SetupPingHandler(stream.PingHandler{ + conn.SetupPingHandler(stream.PingHandler{ Websocket: true, Delay: time.Second * 15, Message: pingMessage, MessageType: websocket.TextMessage, }) - g.Websocket.Wg.Add(1) - go g.wsReadConnData() return nil } @@ -94,22 +92,8 @@ func (g *Gateio) generateWsSignature(secret, event, channel string, dtime time.T return hex.EncodeToString(mac.Sum(nil)), nil } -// wsReadConnData receives and passes on websocket messages for processing -func (g *Gateio) wsReadConnData() { - defer g.Websocket.Wg.Done() - for { - resp := g.Websocket.Conn.ReadMessage() - if resp.Raw == nil { - return - } - err := g.wsHandleData(resp.Raw) - if err != nil { - g.Websocket.DataHandler <- err - } - } -} - -func (g *Gateio) wsHandleData(respRaw []byte) error { +// WsHandleSpotData handles spot data +func (g *Gateio) WsHandleSpotData(ctx context.Context, respRaw []byte) error { var push WsResponse err := json.Unmarshal(respRaw, &push) if err != nil { @@ -625,7 +609,7 @@ func (g *Gateio) processCrossMarginLoans(data []byte) error { } // GenerateDefaultSubscriptions returns default subscriptions -func (g *Gateio) GenerateDefaultSubscriptions() (subscription.List, error) { +func (g *Gateio) GenerateDefaultSubscriptionsSpot() (subscription.List, error) { channelsToSubscribe := defaultSubscriptions if g.Websocket.CanUseAuthenticatedEndpoints() { channelsToSubscribe = append(channelsToSubscribe, []string{ @@ -690,14 +674,14 @@ func (g *Gateio) GenerateDefaultSubscriptions() (subscription.List, error) { } // handleSubscription sends a websocket message to receive data from the channel -func (g *Gateio) handleSubscription(event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generatePayload(event, channelsToSubscribe) +func (g *Gateio) handleSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) error { + payloads, err := g.generatePayload(ctx, conn, event, channelsToSubscribe) if err != nil { return err } var errs error for k := range payloads { - result, err := g.Websocket.Conn.SendMessageReturnResponse(payloads[k].ID, payloads[k]) + result, err := conn.SendMessageReturnResponse(payloads[k].ID, payloads[k]) if err != nil { errs = common.AppendError(errs, err) continue @@ -723,14 +707,14 @@ func (g *Gateio) handleSubscription(event string, channelsToSubscribe subscripti return errs } -func (g *Gateio) generatePayload(event string, channelsToSubscribe subscription.List) ([]WsInput, error) { +func (g *Gateio) generatePayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { return nil, errors.New("cannot generate payload, no channels supplied") } var creds *account.Credentials var err error if g.Websocket.CanUseAuthenticatedEndpoints() { - creds, err = g.GetCredentials(context.TODO()) + creds, err = g.GetCredentials(ctx) if err != nil { return nil, err } @@ -816,7 +800,7 @@ func (g *Gateio) generatePayload(event string, channelsToSubscribe subscription. } payload := WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: conn.GenerateMessageID(false), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, @@ -842,14 +826,14 @@ func (g *Gateio) generatePayload(event string, channelsToSubscribe subscription. return payloads, nil } -// Subscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) Subscribe(channelsToUnsubscribe subscription.List) error { - return g.handleSubscription("subscribe", channelsToUnsubscribe) +// SpotSubscribe sends a websocket message to stop receiving data from the channel +func (g *Gateio) SpotSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleSubscription(ctx, conn, "subscribe", channelsToUnsubscribe) } -// Unsubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) Unsubscribe(channelsToUnsubscribe subscription.List) error { - return g.handleSubscription("unsubscribe", channelsToUnsubscribe) +// SpotUnsubscribe sends a websocket message to stop receiving data from the channel +func (g *Gateio) SpotUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleSubscription(ctx, conn, "unsubscribe", channelsToUnsubscribe) } func (g *Gateio) listOfAssetsCurrencyPairEnabledFor(cp currency.Pair) map[asset.Item]bool { diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index ad59fe38c23..92f03ef74c6 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -27,6 +27,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" "github.com/thrasher-corp/gocryptotrader/log" @@ -159,18 +160,18 @@ func (g *Gateio) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - err = g.DisableAssetWebsocketSupport(asset.Futures) - if err != nil { - log.Errorln(log.ExchangeSys, err) - } - err = g.DisableAssetWebsocketSupport(asset.DeliveryFutures) - if err != nil { - log.Errorln(log.ExchangeSys, err) - } - err = g.DisableAssetWebsocketSupport(asset.Options) - if err != nil { - log.Errorln(log.ExchangeSys, err) - } + // err = g.DisableAssetWebsocketSupport(asset.Futures) + // if err != nil { + // log.Errorln(log.ExchangeSys, err) + // } + // err = g.DisableAssetWebsocketSupport(asset.DeliveryFutures) + // if err != nil { + // log.Errorln(log.ExchangeSys, err) + // } + // err = g.DisableAssetWebsocketSupport(asset.Options) + // if err != nil { + // log.Errorln(log.ExchangeSys, err) + // } g.API.Endpoints = g.NewEndpoints() err = g.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: gateioTradeURL, @@ -208,26 +209,110 @@ func (g *Gateio) Setup(exch *config.Exchange) error { } err = g.Websocket.Setup(&stream.WebsocketSetup{ - ExchangeConfig: exch, - DefaultURL: gateioWebsocketEndpoint, - RunningURL: wsRunningURL, - Connector: g.WsConnect, - Subscriber: g.Subscribe, - Unsubscriber: g.Unsubscribe, - GenerateSubscriptions: g.GenerateDefaultSubscriptions, - Features: &g.Features.Supports.WebsocketCapabilities, - FillsFeed: g.Features.Enabled.FillsFeed, - TradeFeed: g.Features.Enabled.TradeFeed, + ExchangeConfig: exch, + DefaultURL: gateioWebsocketEndpoint, + RunningURL: wsRunningURL, + Features: &g.Features.Supports.WebsocketCapabilities, + FillsFeed: g.Features.Enabled.FillsFeed, + TradeFeed: g.Features.Enabled.TradeFeed, }) if err != nil { return err } - return g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - URL: gateioWebsocketEndpoint, + // Spot connection + err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + URL: gateioWebsocketEndpoint, + RateLimit: gateioWebsocketRateLimit, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleSpotData, + Subscriber: g.SpotSubscribe, + Unsubscriber: g.SpotUnsubscribe, + GenerateSubscriptions: g.GenerateDefaultSubscriptionsSpot, + Connector: g.WsConnectSpot, + Enabled: func() bool { return g.CheckWebsocketEnabled(asset.Spot) }, + }) + if err != nil { + return err + } + // Futures connection - USDT margined + err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + URL: futuresWebsocketUsdtURL, + RateLimit: gateioWebsocketRateLimit, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: func(ctx context.Context, incoming []byte) error { + return g.WsHandleFuturesData(ctx, incoming, asset.Futures) + }, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, + Connector: g.WsFuturesConnect, + Enabled: func() bool { return g.CheckWebsocketEnabled(asset.Futures) }, + }) + if err != nil { + return err + } + + // Futures connection - BTC margined + err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + URL: futuresWebsocketBtcURL, RateLimit: gateioWebsocketRateLimit, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: func(ctx context.Context, incoming []byte) error { + return g.WsHandleFuturesData(ctx, incoming, asset.Futures) + }, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, + Connector: g.WsFuturesConnect, + Enabled: func() bool { return g.CheckWebsocketEnabled(asset.Futures) }, }) + if err != nil { + return err + } + + // Futures connection - Delivery - USDT margined + err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + URL: deliveryRealUSDTTradingURL, + RateLimit: gateioWebsocketRateLimit, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: func(ctx context.Context, incoming []byte) error { + return g.WsHandleFuturesData(ctx, incoming, asset.DeliveryFutures) + }, + Subscriber: g.DeliveryFuturesSubscribe, + Unsubscriber: g.DeliveryFuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateDeliveryFuturesDefaultSubscriptions(currency.BTC) }, + Connector: g.WsDeliveryFuturesConnect, + Enabled: func() bool { return true }, + }) + if err != nil { + return err + } + + // TODO: Add BTC margined delivery futures. + + // Futures connection - Options + return g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + URL: optionsWebsocketURL, + RateLimit: gateioWebsocketRateLimit, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleOptionsData, + Subscriber: g.OptionsSubscribe, + Unsubscriber: g.OptionsUnsubscribe, + GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, + Connector: g.WsOptionsConnect, + Enabled: func() bool { return true }, + }) +} + +// CheckWebsocketEnabled checks if the websocket is enabled for an individual asset +func (g *Gateio) CheckWebsocketEnabled(a asset.Item) bool { + err := g.CurrencyPairs.IsAssetEnabled(asset.Futures) + return err == nil && g.AssetWebsocketSupport.IsAssetWebsocketSupported(a) } // UpdateTicker updates and returns the ticker for a currency pair diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index b9242981033..5bb71b20dd5 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -18,7 +18,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" - "github.com/thrasher-corp/gocryptotrader/log" ) const ( @@ -38,13 +37,10 @@ var defaultDeliveryFuturesSubscriptions = []string{ futuresCandlesticksChannel, } -// responseDeliveryFuturesStream a channel thought which the data coming from the two websocket connection will go through. -var responseDeliveryFuturesStream = make(chan stream.Response) - var fetchedFuturesCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsDeliveryFuturesConnect initiates a websocket connection for delivery futures account -func (g *Gateio) WsDeliveryFuturesConnect() error { +func (g *Gateio) WsDeliveryFuturesConnect(ctx context.Context, conn stream.Connection) error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { return stream.ErrWebsocketNotEnabled } @@ -53,45 +49,19 @@ func (g *Gateio) WsDeliveryFuturesConnect() error { return err } var dialer websocket.Dialer - err = g.Websocket.SetWebsocketURL(deliveryRealUSDTTradingURL, false, true) - if err != nil { - return err - } - err = g.Websocket.Conn.Dial(&dialer, http.Header{}) - if err != nil { - return err - } - err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - URL: deliveryRealBTCTradingURL, - RateLimit: gateioWebsocketRateLimit, - ResponseCheckTimeout: g.Config.WebsocketResponseCheckTimeout, - ResponseMaxLimit: g.Config.WebsocketResponseMaxLimit, - Authenticated: true, - }) - if err != nil { - return err - } - err = g.Websocket.AuthConn.Dial(&dialer, http.Header{}) + err = conn.DialContext(ctx, &dialer, http.Header{}) if err != nil { return err } - g.Websocket.Wg.Add(3) - go g.wsReadDeliveryFuturesData() - go g.wsFunnelDeliveryFuturesConnectionData(g.Websocket.Conn) - go g.wsFunnelDeliveryFuturesConnectionData(g.Websocket.AuthConn) - if g.Verbose { - log.Debugf(log.ExchangeSys, "successful connection to %v\n", - g.Websocket.GetWebsocketURL()) - } pingMessage, err := json.Marshal(WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), - Time: time.Now().Unix(), + ID: conn.GenerateMessageID(false), + Time: time.Now().Unix(), // TODO: func for dynamic time Channel: futuresPingChannel, }) if err != nil { return err } - g.Websocket.Conn.SetupPingHandler(stream.PingHandler{ + conn.SetupPingHandler(stream.PingHandler{ Websocket: true, Delay: time.Second * 5, MessageType: websocket.PingMessage, @@ -100,49 +70,8 @@ func (g *Gateio) WsDeliveryFuturesConnect() error { return nil } -// wsReadDeliveryFuturesData read coming messages thought the websocket connection and pass the data to wsHandleFuturesData for further process. -func (g *Gateio) wsReadDeliveryFuturesData() { - defer g.Websocket.Wg.Done() - for { - select { - case <-g.Websocket.ShutdownC: - select { - case resp := <-responseDeliveryFuturesStream: - err := g.wsHandleFuturesData(resp.Raw, asset.DeliveryFutures) - if err != nil { - select { - case g.Websocket.DataHandler <- err: - default: - log.Errorf(log.WebsocketMgr, "%s websocket handle data error: %v", g.Name, err) - } - } - default: - } - return - case resp := <-responseDeliveryFuturesStream: - err := g.wsHandleFuturesData(resp.Raw, asset.DeliveryFutures) - if err != nil { - g.Websocket.DataHandler <- err - } - } - } -} - -// wsFunnelDeliveryFuturesConnectionData receives data from multiple connection and pass the data -// to wsRead through a channel responseStream -func (g *Gateio) wsFunnelDeliveryFuturesConnectionData(ws stream.Connection) { - defer g.Websocket.Wg.Done() - for { - resp := ws.ReadMessage() - if resp.Raw == nil { - return - } - responseDeliveryFuturesStream <- stream.Response{Raw: resp.Raw} - } -} - // GenerateDeliveryFuturesDefaultSubscriptions returns delivery futures default subscriptions params. -func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.List, error) { +func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions(_ currency.Code) (subscription.List, error) { _, err := g.GetCredentials(context.Background()) if err != nil { g.Websocket.SetCanUseAuthenticatedEndpoints(false) @@ -160,6 +89,27 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.Lis if err != nil { return nil, err } + + // switch { + // case settlement.Equal(currency.USDT): + // pairs, err = pairs.GetPairsByQuote(currency.USDT) + // if err != nil { + // return nil, err + // } + // case settlement.Equal(currency.BTC): + // offset := 0 + // for x := range pairs { + // if pairs[x].Quote.Equal(currency.USDT) { + // continue // skip USDT pairs + // } + // pairs[offset] = pairs[x] + // offset++ + // } + // pairs = pairs[:offset] + // default: + // return nil, fmt.Errorf("settlement currency %s not supported", settlement) + // } + var subscriptions subscription.List for i := range channelsToSubscribe { for j := range pairs { @@ -186,18 +136,18 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.Lis } // DeliveryFuturesSubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) DeliveryFuturesSubscribe(channelsToUnsubscribe subscription.List) error { - return g.handleDeliveryFuturesSubscription("subscribe", channelsToUnsubscribe) +func (g *Gateio) DeliveryFuturesSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleDeliveryFuturesSubscription(ctx, conn, "subscribe", channelsToUnsubscribe) } // DeliveryFuturesUnsubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) DeliveryFuturesUnsubscribe(channelsToUnsubscribe subscription.List) error { - return g.handleDeliveryFuturesSubscription("unsubscribe", channelsToUnsubscribe) +func (g *Gateio) DeliveryFuturesUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleDeliveryFuturesSubscription(ctx, conn, "unsubscribe", channelsToUnsubscribe) } // handleDeliveryFuturesSubscription sends a websocket message to receive data from the channel -func (g *Gateio) handleDeliveryFuturesSubscription(event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generateDeliveryFuturesPayload(event, channelsToSubscribe) +func (g *Gateio) handleDeliveryFuturesSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) error { + payloads, err := g.generateDeliveryFuturesPayload(ctx, conn, event, channelsToSubscribe) if err != nil { return err } @@ -207,9 +157,10 @@ func (g *Gateio) handleDeliveryFuturesSubscription(event string, channelsToSubsc for con, val := range payloads { for k := range val { if con == 0 { - respByte, err = g.Websocket.Conn.SendMessageReturnResponse(val[k].ID, val[k]) + respByte, err = conn.SendMessageReturnResponse(val[k].ID, val[k]) } else { - respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(val[k].ID, val[k]) + // TODO: Split into two. + respByte, err = conn.SendMessageReturnResponse(val[k].ID, val[k]) } if err != nil { errs = common.AppendError(errs, err) @@ -232,7 +183,7 @@ func (g *Gateio) handleDeliveryFuturesSubscription(event string, channelsToSubsc return errs } -func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscribe subscription.List) ([2][]WsInput, error) { +func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([2][]WsInput, error) { payloads := [2][]WsInput{} if len(channelsToSubscribe) == 0 { return payloads, errors.New("cannot generate payload, no channels supplied") @@ -240,7 +191,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib var creds *account.Credentials var err error if g.Websocket.CanUseAuthenticatedEndpoints() { - creds, err = g.GetCredentials(context.TODO()) + creds, err = g.GetCredentials(ctx) if err != nil { g.Websocket.SetCanUseAuthenticatedEndpoints(false) } @@ -316,7 +267,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib } if strings.HasPrefix(channelsToSubscribe[i].Pairs[0].Quote.Upper().String(), "USDT") { payloads[0] = append(payloads[0], WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: conn.GenerateMessageID(false), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, @@ -325,7 +276,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib }) } else { payloads[1] = append(payloads[1], WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: conn.GenerateMessageID(false), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 8962212d0ca..b1e10de5a81 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -23,7 +23,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" - "github.com/thrasher-corp/gocryptotrader/log" ) const ( @@ -58,11 +57,8 @@ var defaultFuturesSubscriptions = []string{ futuresCandlesticksChannel, } -// responseFuturesStream a channel thought which the data coming from the two websocket connection will go through. -var responseFuturesStream = make(chan stream.Response) - // WsFuturesConnect initiates a websocket connection for futures account -func (g *Gateio) WsFuturesConnect() error { +func (g *Gateio) WsFuturesConnect(ctx context.Context, conn stream.Connection) error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { return stream.ErrWebsocketNotEnabled } @@ -75,44 +71,19 @@ func (g *Gateio) WsFuturesConnect() error { if err != nil { return err } - err = g.Websocket.Conn.Dial(&dialer, http.Header{}) - if err != nil { - return err - } - - err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - URL: futuresWebsocketBtcURL, - RateLimit: gateioWebsocketRateLimit, - ResponseCheckTimeout: g.Config.WebsocketResponseCheckTimeout, - ResponseMaxLimit: g.Config.WebsocketResponseMaxLimit, - Authenticated: true, - }) - if err != nil { - return err - } - err = g.Websocket.AuthConn.Dial(&dialer, http.Header{}) + err = conn.DialContext(ctx, &dialer, http.Header{}) if err != nil { return err } - g.Websocket.Wg.Add(3) - go g.wsReadFuturesData() - go g.wsFunnelFuturesConnectionData(g.Websocket.Conn) - go g.wsFunnelFuturesConnectionData(g.Websocket.AuthConn) - if g.Verbose { - log.Debugf(log.ExchangeSys, "Successful connection to %v\n", - g.Websocket.GetWebsocketURL()) - } pingMessage, err := json.Marshal(WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), - Time: func() int64 { - return time.Now().Unix() - }(), + ID: conn.GenerateMessageID(false), + Time: time.Now().Unix(), // TODO: This should be a timer function. Channel: futuresPingChannel, }) if err != nil { return err } - g.Websocket.Conn.SetupPingHandler(stream.PingHandler{ + conn.SetupPingHandler(stream.PingHandler{ Websocket: true, MessageType: websocket.PingMessage, Delay: time.Second * 15, @@ -122,7 +93,7 @@ func (g *Gateio) WsFuturesConnect() error { } // GenerateFuturesDefaultSubscriptions returns default subscriptions information. -func (g *Gateio) GenerateFuturesDefaultSubscriptions() (subscription.List, error) { +func (g *Gateio) GenerateFuturesDefaultSubscriptions(settlement currency.Code) (subscription.List, error) { channelsToSubscribe := defaultFuturesSubscriptions if g.Websocket.CanUseAuthenticatedEndpoints() { channelsToSubscribe = append(channelsToSubscribe, @@ -135,6 +106,27 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions() (subscription.List, error if err != nil { return nil, err } + + switch { + case settlement.Equal(currency.USDT): + pairs, err = pairs.GetPairsByQuote(currency.USDT) + if err != nil { + return nil, err + } + case settlement.Equal(currency.BTC): + offset := 0 + for x := range pairs { + if pairs[x].Quote.Equal(currency.USDT) { + continue // skip USDT pairs + } + pairs[offset] = pairs[x] + offset++ + } + pairs = pairs[:offset] + default: + return nil, fmt.Errorf("settlement currency %s not supported", settlement) + } + subscriptions := make(subscription.List, len(channelsToSubscribe)*len(pairs)) count := 0 for i := range channelsToSubscribe { @@ -166,57 +158,17 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions() (subscription.List, error } // FuturesSubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) FuturesSubscribe(channelsToUnsubscribe subscription.List) error { - return g.handleFuturesSubscription("subscribe", channelsToUnsubscribe) +func (g *Gateio) FuturesSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleFuturesSubscription(ctx, conn, "subscribe", channelsToUnsubscribe) } // FuturesUnsubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) FuturesUnsubscribe(channelsToUnsubscribe subscription.List) error { - return g.handleFuturesSubscription("unsubscribe", channelsToUnsubscribe) -} - -// wsReadFuturesData read coming messages thought the websocket connection and pass the data to wsHandleData for further process. -func (g *Gateio) wsReadFuturesData() { - defer g.Websocket.Wg.Done() - for { - select { - case <-g.Websocket.ShutdownC: - select { - case resp := <-responseFuturesStream: - err := g.wsHandleFuturesData(resp.Raw, asset.Futures) - if err != nil { - select { - case g.Websocket.DataHandler <- err: - default: - log.Errorf(log.WebsocketMgr, "%s websocket handle data error: %v", g.Name, err) - } - } - default: - } - return - case resp := <-responseFuturesStream: - err := g.wsHandleFuturesData(resp.Raw, asset.Futures) - if err != nil { - g.Websocket.DataHandler <- err - } - } - } -} - -// wsFunnelFuturesConnectionData receives data from multiple connection and pass the data -// to wsRead through a channel responseStream -func (g *Gateio) wsFunnelFuturesConnectionData(ws stream.Connection) { - defer g.Websocket.Wg.Done() - for { - resp := ws.ReadMessage() - if resp.Raw == nil { - return - } - responseFuturesStream <- stream.Response{Raw: resp.Raw} - } +func (g *Gateio) FuturesUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleFuturesSubscription(ctx, conn, "unsubscribe", channelsToUnsubscribe) } -func (g *Gateio) wsHandleFuturesData(respRaw []byte, assetType asset.Item) error { +// WsHandleFuturesData handles futures websocket data +func (g *Gateio) WsHandleFuturesData(ctx context.Context, respRaw []byte, a asset.Item) error { var push WsResponse err := json.Unmarshal(respRaw, &push) if err != nil { @@ -232,27 +184,27 @@ func (g *Gateio) wsHandleFuturesData(respRaw []byte, assetType asset.Item) error switch push.Channel { case futuresTickersChannel: - return g.processFuturesTickers(respRaw, assetType) + return g.processFuturesTickers(respRaw, a) case futuresTradesChannel: - return g.processFuturesTrades(respRaw, assetType) + return g.processFuturesTrades(respRaw, a) case futuresOrderbookChannel: - return g.processFuturesOrderbookSnapshot(push.Event, push.Result, assetType, push.Time.Time()) + return g.processFuturesOrderbookSnapshot(push.Event, push.Result, a, push.Time.Time()) case futuresOrderbookTickerChannel: return g.processFuturesOrderbookTicker(push.Result) case futuresOrderbookUpdateChannel: - return g.processFuturesAndOptionsOrderbookUpdate(push.Result, assetType) + return g.processFuturesAndOptionsOrderbookUpdate(push.Result, a) case futuresCandlesticksChannel: - return g.processFuturesCandlesticks(respRaw, assetType) + return g.processFuturesCandlesticks(respRaw, a) case futuresOrdersChannel: var processed []order.Detail - processed, err = g.processFuturesOrdersPushData(respRaw, assetType) + processed, err = g.processFuturesOrdersPushData(respRaw, a) if err != nil { return err } g.Websocket.DataHandler <- processed return nil case futuresUserTradesChannel: - return g.procesFuturesUserTrades(respRaw, assetType) + return g.procesFuturesUserTrades(respRaw, a) case futuresLiquidatesChannel: return g.processFuturesLiquidatesNotification(respRaw) case futuresAutoDeleveragesChannel: @@ -260,7 +212,7 @@ func (g *Gateio) wsHandleFuturesData(respRaw []byte, assetType asset.Item) error case futuresAutoPositionCloseChannel: return g.processPositionCloseData(respRaw) case futuresBalancesChannel: - return g.processBalancePushData(respRaw, assetType) + return g.processBalancePushData(respRaw, a) case futuresReduceRiskLimitsChannel: return g.processFuturesReduceRiskLimitNotification(respRaw) case futuresPositionsChannel: @@ -276,8 +228,8 @@ func (g *Gateio) wsHandleFuturesData(respRaw []byte, assetType asset.Item) error } // handleFuturesSubscription sends a websocket message to receive data from the channel -func (g *Gateio) handleFuturesSubscription(event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generateFuturesPayload(event, channelsToSubscribe) +func (g *Gateio) handleFuturesSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) error { + payloads, err := g.generateFuturesPayload(ctx, conn, event, channelsToSubscribe) if err != nil { return err } @@ -287,9 +239,10 @@ func (g *Gateio) handleFuturesSubscription(event string, channelsToSubscribe sub for con, val := range payloads { for k := range val { if con == 0 { - respByte, err = g.Websocket.Conn.SendMessageReturnResponse(val[k].ID, val[k]) + respByte, err = conn.SendMessageReturnResponse(val[k].ID, val[k]) } else { - respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(val[k].ID, val[k]) + // TODO: Authconn test temp + respByte, err = conn.SendMessageReturnResponse(val[k].ID, val[k]) } if err != nil { errs = common.AppendError(errs, err) @@ -315,7 +268,7 @@ func (g *Gateio) handleFuturesSubscription(event string, channelsToSubscribe sub return nil } -func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe subscription.List) ([2][]WsInput, error) { +func (g *Gateio) generateFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([2][]WsInput, error) { payloads := [2][]WsInput{} if len(channelsToSubscribe) == 0 { return payloads, errors.New("cannot generate payload, no channels supplied") @@ -323,7 +276,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe subscr var creds *account.Credentials var err error if g.Websocket.CanUseAuthenticatedEndpoints() { - creds, err = g.GetCredentials(context.TODO()) + creds, err = g.GetCredentials(ctx) if err != nil { g.Websocket.SetCanUseAuthenticatedEndpoints(false) } @@ -401,7 +354,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe subscr } if strings.HasPrefix(channelsToSubscribe[i].Pairs[0].Quote.Upper().String(), "USDT") { payloads[0] = append(payloads[0], WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: conn.GenerateMessageID(false), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, @@ -410,7 +363,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe subscr }) } else { payloads[1] = append(payloads[1], WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: conn.GenerateMessageID(false), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index fe1384c8e52..949b98ce2f4 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -68,7 +68,7 @@ var defaultOptionsSubscriptions = []string{ var fetchedOptionsCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsOptionsConnect initiates a websocket connection to options websocket endpoints. -func (g *Gateio) WsOptionsConnect() error { +func (g *Gateio) WsOptionsConnect(ctx context.Context, conn stream.Connection) error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { return stream.ErrWebsocketNotEnabled } @@ -81,21 +81,19 @@ func (g *Gateio) WsOptionsConnect() error { if err != nil { return err } - err = g.Websocket.Conn.Dial(&dialer, http.Header{}) + err = conn.DialContext(ctx, &dialer, http.Header{}) if err != nil { return err } pingMessage, err := json.Marshal(WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: conn.GenerateMessageID(false), Time: time.Now().Unix(), Channel: optionsPingChannel, }) if err != nil { return err } - g.Websocket.Wg.Add(1) - go g.wsReadOptionsConnData() - g.Websocket.Conn.SetupPingHandler(stream.PingHandler{ + conn.SetupPingHandler(stream.PingHandler{ Websocket: true, Delay: time.Second * 5, MessageType: websocket.PingMessage, @@ -173,7 +171,7 @@ getEnabledPairs: return subscriptions, nil } -func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe subscription.List) ([]WsInput, error) { +func (g *Gateio) generateOptionsPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { return nil, errors.New("cannot generate payload, no channels supplied") } @@ -232,7 +230,7 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe subscr } params = append([]string{strconv.FormatInt(userID, 10)}, params...) var creds *account.Credentials - creds, err = g.GetCredentials(context.Background()) + creds, err = g.GetCredentials(ctx) if err != nil { return nil, err } @@ -275,7 +273,7 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe subscr params...) } payloads[i] = WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: conn.GenerateMessageID(false), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, @@ -286,40 +284,25 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe subscr return payloads, nil } -// wsReadOptionsConnData receives and passes on websocket messages for processing -func (g *Gateio) wsReadOptionsConnData() { - defer g.Websocket.Wg.Done() - for { - resp := g.Websocket.Conn.ReadMessage() - if resp.Raw == nil { - return - } - err := g.wsHandleOptionsData(resp.Raw) - if err != nil { - g.Websocket.DataHandler <- err - } - } -} - // OptionsSubscribe sends a websocket message to stop receiving data for asset type options -func (g *Gateio) OptionsSubscribe(channelsToUnsubscribe subscription.List) error { - return g.handleOptionsSubscription("subscribe", channelsToUnsubscribe) +func (g *Gateio) OptionsSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleOptionsSubscription(ctx, conn, "subscribe", channelsToUnsubscribe) } // OptionsUnsubscribe sends a websocket message to stop receiving data for asset type options -func (g *Gateio) OptionsUnsubscribe(channelsToUnsubscribe subscription.List) error { - return g.handleOptionsSubscription("unsubscribe", channelsToUnsubscribe) +func (g *Gateio) OptionsUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleOptionsSubscription(ctx, conn, "unsubscribe", channelsToUnsubscribe) } // handleOptionsSubscription sends a websocket message to receive data from the channel -func (g *Gateio) handleOptionsSubscription(event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generateOptionsPayload(event, channelsToSubscribe) +func (g *Gateio) handleOptionsSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) error { + payloads, err := g.generateOptionsPayload(ctx, conn, event, channelsToSubscribe) if err != nil { return err } var errs error for k := range payloads { - result, err := g.Websocket.Conn.SendMessageReturnResponse(payloads[k].ID, payloads[k]) + result, err := conn.SendMessageReturnResponse(payloads[k].ID, payloads[k]) if err != nil { errs = common.AppendError(errs, err) continue @@ -345,7 +328,8 @@ func (g *Gateio) handleOptionsSubscription(event string, channelsToSubscribe sub return errs } -func (g *Gateio) wsHandleOptionsData(respRaw []byte) error { +// WsHandleOptionsData handles options websocket data +func (g *Gateio) WsHandleOptionsData(ctx context.Context, respRaw []byte) error { var push WsResponse err := json.Unmarshal(respRaw, &push) if err != nil { diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index e342c74dace..d69625ca7c2 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -1,6 +1,7 @@ package stream import ( + "context" "net/http" "time" @@ -8,11 +9,13 @@ import ( "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" ) // Connection defines a streaming services connection type Connection interface { Dial(*websocket.Dialer, http.Header) error + DialContext(context.Context, *websocket.Dialer, http.Header) error ReadMessage() Response SendJSONMessage(interface{}) error SetupPingHandler(PingHandler) @@ -39,6 +42,12 @@ type ConnectionSetup struct { URL string Authenticated bool ConnectionLevelReporter Reporter + Handler func(ctx context.Context, incoming []byte) error + Subscriber func(ctx context.Context, conn Connection, sub subscription.List) error + Unsubscriber func(ctx context.Context, conn Connection, unsub subscription.List) error + GenerateSubscriptions func() (subscription.List, error) + Connector func(ctx context.Context, conn Connection) error + Enabled func() bool } // PingHandler container for ping handler settings diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index badb5d565a4..073c022f36d 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -1,6 +1,7 @@ package stream import ( + "context" "errors" "fmt" "net" @@ -11,6 +12,7 @@ import ( "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/config" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" @@ -35,7 +37,6 @@ var ( var ( errAlreadyRunning = errors.New("connection monitor is already running") errExchangeConfigIsNil = errors.New("exchange config is nil") - errExchangeConfigEmpty = errors.New("exchange config is empty") errWebsocketIsNil = errors.New("websocket is nil") errWebsocketSetupIsNil = errors.New("websocket setup is nil") errWebsocketAlreadyInitialised = errors.New("websocket already initialised") @@ -51,6 +52,7 @@ var ( errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") errWebsocketConnectorUnset = errors.New("websocket connector function not set") + errWebsocketDataHandlerUnset = errors.New("websocket data handler not set") errReadMessageErrorsNil = errors.New("read message errors is nil") errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") errClosedConnection = errors.New("use of closed network connection") @@ -62,6 +64,7 @@ var ( errCannotShutdown = errors.New("websocket cannot shutdown") errAlreadyReconnecting = errors.New("websocket in the process of reconnection") errConnSetup = errors.New("error in connection setup") + errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") ) var ( @@ -127,29 +130,15 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { } w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket) - if s.Connector == nil { - return fmt.Errorf("%s %w", w.exchangeName, errWebsocketConnectorUnset) - } w.connector = s.Connector - - if s.Subscriber == nil { - return fmt.Errorf("%s %w", w.exchangeName, errWebsocketSubscriberUnset) - } w.Subscriber = s.Subscriber + w.Unsubscriber = s.Unsubscriber + w.GenerateSubs = s.GenerateSubscriptions - if w.features.Unsubscribe && s.Unsubscriber == nil { - return fmt.Errorf("%s %w", w.exchangeName, errWebsocketUnsubscriberUnset) - } w.connectionMonitorDelay = s.ExchangeConfig.ConnectionMonitorDelay if w.connectionMonitorDelay <= 0 { w.connectionMonitorDelay = config.DefaultConnectionMonitorDelay } - w.Unsubscriber = s.Unsubscriber - - if s.GenerateSubscriptions == nil { - return fmt.Errorf("%s %w", w.exchangeName, errWebsocketSubscriptionsGeneratorUnset) - } - w.GenerateSubs = s.GenerateSubscriptions if s.DefaultURL == "" { return fmt.Errorf("%s websocket %w", w.exchangeName, errDefaultURLIsEmpty) @@ -202,36 +191,63 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { if w == nil { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketIsNil) } - if c == (ConnectionSetup{}) { - return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty) - } - if w.exchangeName == "" { return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigNameEmpty) } - if w.TrafficAlert == nil { return fmt.Errorf("%w: %w", errConnSetup, errTrafficAlertNil) } - if w.ReadMessageErrors == nil { return fmt.Errorf("%w: %w", errConnSetup, errReadMessageErrorsNil) } - - connectionURL := w.GetWebsocketURL() - if c.URL != "" { - connectionURL = c.URL - } - if c.ConnectionLevelReporter == nil { c.ConnectionLevelReporter = w.ExchangeLevelReporter } - if c.ConnectionLevelReporter == nil { c.ConnectionLevelReporter = globalReporter } - newConn := &WebsocketConnection{ + // If connector is nil, we assume that the connection and supporting + // functions are defined per connection. Else we use the global connector + // and supporting functions for backwards compatibility. + if w.connector == nil { + fmt.Println("w.connector == nil") + if c.Handler == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset) + } + if c.Subscriber == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriberUnset) + } + if c.Unsubscriber == nil && w.features.Unsubscribe { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketUnsubscriberUnset) + } + if c.GenerateSubscriptions == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriptionsGeneratorUnset) + } + if c.Connector == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) + } + w.PendingConnections = append(w.PendingConnections, c) + return nil + } + + if c.Authenticated { + w.AuthConn = w.getConnectionFromSetup(c) + } else { + w.Conn = w.getConnectionFromSetup(c) + } + + return nil +} + +// getConnectionFromSetup returns a websocket connection from a setup +// configuration. This is used for setting up new connections on the fly. +func (w *Websocket) getConnectionFromSetup(c ConnectionSetup) *WebsocketConnection { + connectionURL := w.GetWebsocketURL() + if c.URL != "" { + connectionURL = c.URL + } + return &WebsocketConnection{ ExchangeName: w.exchangeName, URL: connectionURL, ProxyURL: w.GetProxyAddress(), @@ -245,22 +261,11 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { RateLimit: c.RateLimit, Reporter: c.ConnectionLevelReporter, } - - if c.Authenticated { - w.AuthConn = newConn - } else { - w.Conn = newConn - } - - return nil } // Connect initiates a websocket connection by using a package defined connection // function func (w *Websocket) Connect() error { - if w.connector == nil { - return errNoConnectFunc - } w.m.Lock() defer w.m.Unlock() @@ -283,27 +288,86 @@ func (w *Websocket) Connect() error { w.trafficMonitor() w.setState(connectingState) - err := w.connector() - if err != nil { - w.setState(disconnectedState) - return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) - } - w.setState(connectedState) + if w.connector != nil { + fmt.Println("OLD CONNECTOR") + err := w.connector() + if err != nil { + w.setState(disconnectedState) + return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) + } + w.setState(connectedState) - if !w.IsConnectionMonitorRunning() { - err = w.connectionMonitor() + if !w.IsConnectionMonitorRunning() { + err := w.connectionMonitor() + if err != nil { + log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err) + } + } + + subs, err := w.GenerateSubs() // regenerate state on new connection if err != nil { - log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err) + return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) + } + if len(subs) != 0 { + if err := w.SubscribeToChannels(subs); err != nil { + return err + } } + return nil } - subs, err := w.GenerateSubs() // regenerate state on new connection - if err != nil { - return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) + fmt.Println("NEW CONNECTOR") + + if len(w.PendingConnections) == 0 { + return fmt.Errorf("cannot connect: %w", errNoPendingConnections) } - if len(subs) != 0 { - if err := w.SubscribeToChannels(subs); err != nil { - return err + + for i := range w.PendingConnections { + fmt.Println("SPAWN CONNECTION: ", i) + if !w.PendingConnections[i].Enabled() { + fmt.Println("Connection not enabled") + continue + } + subs, err := w.PendingConnections[i].GenerateSubscriptions() // regenerate state on new connection + if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + log.Warnf(log.WebsocketMgr, "%s websocket: %v", w.exchangeName, err) + continue // Non-fatal error, we can continue to the next connection + } + w.setState(disconnectedState) + return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) + } + + fmt.Println("subs: ", len(subs)) + + if len(subs) == 0 { + // If no subscriptions are generated, we skip the connection + log.Warnf(log.WebsocketMgr, "%s websocket: no subscriptions generated", w.exchangeName) + continue + } + + // TODO: Add window to max subscriptions per connection, to spawn new connections if needed. + conn := w.getConnectionFromSetup(w.PendingConnections[i]) + err = w.PendingConnections[i].Connector(context.TODO(), conn) + if err != nil { + return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) + } + w.Wg.Add(1) + go w.Reader(context.TODO(), conn, w.PendingConnections[i].Handler) + + fmt.Println("Subscribing to channels: ", len(subs)) + err = w.PendingConnections[i].Subscriber(context.TODO(), conn, subs) + if err != nil { + return fmt.Errorf("%v Error subscribing %w", w.exchangeName, err) + } + } + + fmt.Println("DONE SPAWNING CONNECTIONS") + + if !w.IsConnectionMonitorRunning() { + err := w.connectionMonitor() + if err != nil { + log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err) } } @@ -962,3 +1026,17 @@ func (w *Websocket) checkSubscriptions(subs subscription.List) error { return nil } + +// Reader reads and handles data from a specific connection +func (w *Websocket) Reader(ctx context.Context, conn Connection, handler func(ctx context.Context, message []byte) error) { + defer w.Wg.Done() + for { + resp := conn.ReadMessage() + if resp.Raw == nil { + return // Connection has been closed + } + if err := handler(ctx, resp.Raw); err != nil { + w.ReadMessageErrors <- err + } + } +} diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 6a00d01ab74..ebc5d340d03 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -4,6 +4,7 @@ import ( "bytes" "compress/flate" "compress/gzip" + "context" "crypto/rand" "encoding/json" "fmt" @@ -66,7 +67,6 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header var err error var conStatus *http.Response - w.Connection, conStatus, err = dialer.Dial(w.URL, headers) if err != nil { if conStatus != nil { @@ -74,7 +74,39 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header } return fmt.Errorf("%s websocket connection: %v Error: %w", w.ExchangeName, w.URL, err) } - defer conStatus.Body.Close() + defer conStatus.Body.Close() // TODO: Close on error above. This is a potential resource leak. + + if w.Verbose { + log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", w.ExchangeName, w.URL) + } + select { + case w.Traffic <- struct{}{}: + default: + } + w.setConnectedStatus(true) + return nil +} + +// DialContext sets proxy urls and then connects to the websocket +func (w *WebsocketConnection) DialContext(ctx context.Context, dialer *websocket.Dialer, headers http.Header) error { + if w.ProxyURL != "" { + proxy, err := url.Parse(w.ProxyURL) + if err != nil { + return err + } + dialer.Proxy = http.ProxyURL(proxy) + } + + var err error + var conStatus *http.Response + w.Connection, conStatus, err = dialer.DialContext(ctx, w.URL, headers) + if err != nil { + if conStatus != nil { + return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", w.ExchangeName, w.URL, conStatus, conStatus.StatusCode, err) + } + return fmt.Errorf("%s websocket connection: %v Error: %w", w.ExchangeName, w.URL, err) + } + defer conStatus.Body.Close() // TODO: Close on error above. This is a potential resource leak. if w.Verbose { log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", w.ExchangeName, w.URL) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 2d44de95097..6e7fa1b888f 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1121,9 +1121,6 @@ func TestSetupNewConnection(t *testing.T) { err = web.Setup(defaultSetup) assert.NoError(t, err, "Setup should not error") - err = web.SetupNewConnection(ConnectionSetup{}) - assert.ErrorIs(t, err, errExchangeConfigEmpty, "SetupNewConnection should error correctly") - err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) assert.NoError(t, err, "SetupNewConnection should not error") diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 5d0009a4642..66473d5721c 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -50,6 +50,8 @@ type Websocket struct { m sync.Mutex connector func() error + PendingConnections []ConnectionSetup + subscriptions *subscription.Store // Subscriber function for exchange specific subscribe implementation From f509399a9714f088c59efe5fb4b26a88d3437aa4 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Mon, 15 Jul 2024 09:55:45 +1000 Subject: [PATCH 002/138] meow --- exchanges/gateio/gateio_test.go | 76 +++++++-------- exchanges/gateio/gateio_wrapper.go | 147 ++++++++++++++--------------- exchanges/stream/stream_types.go | 1 - exchanges/stream/websocket.go | 20 ++-- 4 files changed, 123 insertions(+), 121 deletions(-) diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index ea9bd9d82aa..c6df7abbdf5 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -2654,7 +2654,7 @@ const wsFuturesTickerPushDataJSON = `{"time": 1541659086, "channel": "futures.ti func TestFuturesTicker(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesTickerPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesTickerPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket push data error: %v", g.Name, err) } } @@ -2663,7 +2663,7 @@ const wsFuturesTradesPushDataJSON = `{"channel": "futures.trades","event": "upda func TestFuturesTrades(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesTradesPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesTradesPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket push data error: %v", g.Name, err) } } @@ -2674,7 +2674,7 @@ const ( func TestOrderbookData(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesOrderbookTickerJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesOrderbookTickerJSON), asset.Futures); err != nil { t.Errorf("%s websocket orderbook ticker push data error: %v", g.Name, err) } } @@ -2683,7 +2683,7 @@ const wsFuturesOrderPushDataJSON = `{ "channel": "futures.orders", "event": "upd func TestFuturesOrderPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesOrderPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesOrderPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures order push data error: %v", g.Name, err) } } @@ -2692,7 +2692,7 @@ const wsFuturesUsertradesPushDataJSON = `{"time": 1543205083, "channel": "future func TestFuturesUserTrades(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesUsertradesPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesUsertradesPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures user trades push data error: %v", g.Name, err) } } @@ -2701,7 +2701,7 @@ const wsFuturesLiquidationPushDataJSON = `{"channel": "futures.liquidates", "eve func TestFuturesLiquidationPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesLiquidationPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesLiquidationPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures liquidation push data error: %v", g.Name, err) } } @@ -2710,7 +2710,7 @@ const wsFuturesAutoDelevergesNotification = `{"channel": "futures.auto_deleverag func TestFuturesAutoDeleverges(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesAutoDelevergesNotification), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesAutoDelevergesNotification), asset.Futures); err != nil { t.Errorf("%s websocket futures auto deleverge push data error: %v", g.Name, err) } } @@ -2719,7 +2719,7 @@ const wsFuturesPositionClosePushDataJSON = ` {"channel": "futures.position_close func TestPositionClosePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesPositionClosePushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesPositionClosePushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures position close push data error: %v", g.Name, err) } } @@ -2728,7 +2728,7 @@ const wsFuturesBalanceNotificationPushDataJSON = `{"channel": "futures.balances" func TestFuturesBalanceNotification(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesBalanceNotificationPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesBalanceNotificationPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures balance notification push data error: %v", g.Name, err) } } @@ -2737,7 +2737,7 @@ const wsFuturesReduceRiskLimitNotificationPushDataJSON = `{"time": 1551858330, " func TestFuturesReduceRiskLimitPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesReduceRiskLimitNotificationPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesReduceRiskLimitNotificationPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures reduce risk limit notification push data error: %v", g.Name, err) } } @@ -2746,7 +2746,7 @@ const wsFuturesPositionsNotificationPushDataJSON = `{"time": 1588212926,"channel func TestFuturesPositionsNotification(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesPositionsNotificationPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesPositionsNotificationPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures positions change notification push data error: %v", g.Name, err) } } @@ -2755,7 +2755,7 @@ const wsFuturesAutoOrdersPushDataJSON = `{"time": 1596798126,"channel": "futures func TestFuturesAutoOrderPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesAutoOrdersPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesAutoOrdersPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures auto orders push data error: %v", g.Name, err) } } @@ -2766,7 +2766,7 @@ const optionsContractTickerPushDataJSON = `{"time": 1630576352, "channel": "opti func TestOptionsContractTickerPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsContractTickerPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsContractTickerPushDataJSON)); err != nil { t.Errorf("%s websocket options contract ticker push data failed with error %v", g.Name, err) } } @@ -2775,7 +2775,7 @@ const optionsUnderlyingTickerPushDataJSON = `{"time": 1630576352, "channel": "op func TestOptionsUnderlyingTickerPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsUnderlyingTickerPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsUnderlyingTickerPushDataJSON)); err != nil { t.Errorf("%s websocket options underlying ticker push data error: %v", g.Name, err) } } @@ -2784,7 +2784,7 @@ const optionsContractTradesPushDataJSON = `{"time": 1630576356, "channel": "opti func TestOptionsContractTradesPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsContractTradesPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsContractTradesPushDataJSON)); err != nil { t.Errorf("%s websocket contract trades push data error: %v", g.Name, err) } } @@ -2793,7 +2793,7 @@ const optionsUnderlyingTradesPushDataJSON = `{"time": 1630576356, "channel": "op func TestOptionsUnderlyingTradesPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsUnderlyingTradesPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsUnderlyingTradesPushDataJSON)); err != nil { t.Errorf("%s websocket underlying trades push data error: %v", g.Name, err) } } @@ -2802,7 +2802,7 @@ const optionsUnderlyingPricePushDataJSON = `{ "time": 1630576356, "channel": "op func TestOptionsUnderlyingPricePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsUnderlyingPricePushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsUnderlyingPricePushDataJSON)); err != nil { t.Errorf("%s websocket underlying price push data error: %v", g.Name, err) } } @@ -2811,7 +2811,7 @@ const optionsMarkPricePushDataJSON = `{ "time": 1630576356, "channel": "options. func TestOptionsMarkPricePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsMarkPricePushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsMarkPricePushDataJSON)); err != nil { t.Errorf("%s websocket mark price push data error: %v", g.Name, err) } } @@ -2820,7 +2820,7 @@ const optionsSettlementsPushDataJSON = `{ "time": 1630576356, "channel": "option func TestSettlementsPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsSettlementsPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsSettlementsPushDataJSON)); err != nil { t.Errorf("%s websocket options settlements push data error: %v", g.Name, err) } } @@ -2829,7 +2829,7 @@ const optionsContractPushDataJSON = `{"time": 1630576356, "channel": "options.co func TestOptionsContractPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsContractPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsContractPushDataJSON)); err != nil { t.Errorf("%s websocket options contracts push data error: %v", g.Name, err) } } @@ -2841,10 +2841,10 @@ const ( func TestOptionsCandlesticksPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsContractCandlesticksPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsContractCandlesticksPushDataJSON)); err != nil { t.Errorf("%s websocket options contracts candlestick push data error: %v", g.Name, err) } - if err := g.wsHandleOptionsData([]byte(optionsUnderlyingCandlesticksPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsUnderlyingCandlesticksPushDataJSON)); err != nil { t.Errorf("%s websocket options underlying candlestick push data error: %v", g.Name, err) } } @@ -2858,17 +2858,17 @@ const ( func TestOptionsOrderbookPushData(t *testing.T) { t.Parallel() - err := g.wsHandleOptionsData([]byte(optionsOrderbookTickerPushDataJSON)) + err := g.WsHandleOptionsData(context.Background(), []byte(optionsOrderbookTickerPushDataJSON)) if err != nil { t.Errorf("%s websocket options orderbook ticker push data error: %v", g.Name, err) } - if err = g.wsHandleOptionsData([]byte(optionsOrderbookSnapshotPushDataJSON)); err != nil { + if err = g.WsHandleOptionsData(context.Background(), []byte(optionsOrderbookSnapshotPushDataJSON)); err != nil { t.Errorf("%s websocket options orderbook snapshot push data error: %v", g.Name, err) } - if err = g.wsHandleOptionsData([]byte(optionsOrderbookUpdatePushDataJSON)); err != nil { + if err = g.WsHandleOptionsData(context.Background(), []byte(optionsOrderbookUpdatePushDataJSON)); err != nil { t.Errorf("%s websocket options orderbook update push data error: %v", g.Name, err) } - if err = g.wsHandleOptionsData([]byte(optionsOrderbookSnapshotUpdateEventPushDataJSON)); err != nil { + if err = g.WsHandleOptionsData(context.Background(), []byte(optionsOrderbookSnapshotUpdateEventPushDataJSON)); err != nil { t.Errorf("%s websocket options orderbook snapshot update event push data error: %v", g.Name, err) } } @@ -2877,7 +2877,7 @@ const optionsOrderPushDataJSON = `{"time": 1630654851,"channel": "options.orders func TestOptionsOrderPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsOrderPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsOrderPushDataJSON)); err != nil { t.Errorf("%s websocket options orders push data error: %v", g.Name, err) } } @@ -2886,7 +2886,7 @@ const optionsUsersTradesPushDataJSON = `{ "time": 1639144214, "channel": "option func TestOptionUserTradesPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsUsersTradesPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsUsersTradesPushDataJSON)); err != nil { t.Errorf("%s websocket options orders push data error: %v", g.Name, err) } } @@ -2895,7 +2895,7 @@ const optionsLiquidatesPushDataJSON = `{ "channel": "options.liquidates", "event func TestOptionsLiquidatesPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsLiquidatesPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsLiquidatesPushDataJSON)); err != nil { t.Errorf("%s websocket options liquidates push data error: %v", g.Name, err) } } @@ -2904,7 +2904,7 @@ const optionsSettlementPushDataJSON = `{ "channel": "options.user_settlements", func TestOptionsSettlementPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsSettlementPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsSettlementPushDataJSON)); err != nil { t.Errorf("%s websocket options settlement push data error: %v", g.Name, err) } } @@ -2913,7 +2913,7 @@ const optionsPositionClosePushDataJSON = `{"channel": "options.position_closes", func TestOptionsPositionClosePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsPositionClosePushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsPositionClosePushDataJSON)); err != nil { t.Errorf("%s websocket options position close push data error: %v", g.Name, err) } } @@ -2922,7 +2922,7 @@ const optionsBalancePushDataJSON = `{ "channel": "options.balances", "event": "u func TestOptionsBalancePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsBalancePushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsBalancePushDataJSON)); err != nil { t.Errorf("%s websocket options balance push data error: %v", g.Name, err) } } @@ -2931,7 +2931,7 @@ const optionsPositionPushDataJSON = `{"time": 1630654851, "channel": "options.po func TestOptionsPositionPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsPositionPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsPositionPushDataJSON)); err != nil { t.Errorf("%s websocket options position push data error: %v", g.Name, err) } } @@ -2943,11 +2943,11 @@ const ( func TestFuturesOrderbookPushData(t *testing.T) { t.Parallel() - err := g.wsHandleFuturesData([]byte(futuresOrderbookPushData), asset.Futures) + err := g.WsHandleFuturesData(context.Background(), []byte(futuresOrderbookPushData), asset.Futures) if err != nil { t.Error(err) } - err = g.wsHandleFuturesData([]byte(futuresOrderbookUpdatePushData), asset.Futures) + err = g.WsHandleFuturesData(context.Background(), []byte(futuresOrderbookUpdatePushData), asset.Futures) if err != nil { t.Error(err) } @@ -2957,7 +2957,7 @@ const futuresCandlesticksPushData = `{"time": 1678469467, "time_ms": 16784694679 func TestFuturesCandlestickPushData(t *testing.T) { t.Parallel() - err := g.wsHandleFuturesData([]byte(futuresCandlesticksPushData), asset.Futures) + err := g.WsHandleFuturesData(context.Background(), []byte(futuresCandlesticksPushData), asset.Futures) if err != nil { t.Error(err) } @@ -2971,13 +2971,13 @@ func TestGenerateDefaultSubscriptionsSpot(t *testing.T) { } func TestGenerateDeliveryFuturesDefaultSubscriptions(t *testing.T) { t.Parallel() - if _, err := g.GenerateDeliveryFuturesDefaultSubscriptions(); err != nil { + if _, err := g.GenerateDeliveryFuturesDefaultSubscriptions(currency.USDT); err != nil { t.Error(err) } } func TestGenerateFuturesDefaultSubscriptions(t *testing.T) { t.Parallel() - if _, err := g.GenerateFuturesDefaultSubscriptions(); err != nil { + if _, err := g.GenerateFuturesDefaultSubscriptions(currency.USDT); err != nil { t.Error(err) } } diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 92f03ef74c6..8c5b1c7a905 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -27,7 +27,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" "github.com/thrasher-corp/gocryptotrader/log" @@ -230,89 +229,85 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Unsubscriber: g.SpotUnsubscribe, GenerateSubscriptions: g.GenerateDefaultSubscriptionsSpot, Connector: g.WsConnectSpot, - Enabled: func() bool { return g.CheckWebsocketEnabled(asset.Spot) }, - }) - if err != nil { - return err - } - // Futures connection - USDT margined - err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - URL: futuresWebsocketUsdtURL, - RateLimit: gateioWebsocketRateLimit, - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: func(ctx context.Context, incoming []byte) error { - return g.WsHandleFuturesData(ctx, incoming, asset.Futures) - }, - Subscriber: g.FuturesSubscribe, - Unsubscriber: g.FuturesUnsubscribe, - GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, - Connector: g.WsFuturesConnect, - Enabled: func() bool { return g.CheckWebsocketEnabled(asset.Futures) }, - }) - if err != nil { - return err - } - - // Futures connection - BTC margined - err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - URL: futuresWebsocketBtcURL, - RateLimit: gateioWebsocketRateLimit, - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: func(ctx context.Context, incoming []byte) error { - return g.WsHandleFuturesData(ctx, incoming, asset.Futures) - }, - Subscriber: g.FuturesSubscribe, - Unsubscriber: g.FuturesUnsubscribe, - GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, - Connector: g.WsFuturesConnect, - Enabled: func() bool { return g.CheckWebsocketEnabled(asset.Futures) }, + // Enabled: func() bool { return g.CheckWebsocketEnabled(asset.Spot) }, }) if err != nil { return err } + // // Futures connection - USDT margined + // err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + // URL: futuresWebsocketUsdtURL, + // RateLimit: gateioWebsocketRateLimit, + // ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + // ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + // Handler: func(ctx context.Context, incoming []byte) error { + // return g.WsHandleFuturesData(ctx, incoming, asset.Futures) + // }, + // Subscriber: g.FuturesSubscribe, + // Unsubscriber: g.FuturesUnsubscribe, + // GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, + // Connector: g.WsFuturesConnect, + // // Enabled: func() bool { return g.CheckWebsocketEnabled(asset.Futures) }, + // }) + // if err != nil { + // return err + // } - // Futures connection - Delivery - USDT margined - err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - URL: deliveryRealUSDTTradingURL, - RateLimit: gateioWebsocketRateLimit, - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: func(ctx context.Context, incoming []byte) error { - return g.WsHandleFuturesData(ctx, incoming, asset.DeliveryFutures) - }, - Subscriber: g.DeliveryFuturesSubscribe, - Unsubscriber: g.DeliveryFuturesUnsubscribe, - GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateDeliveryFuturesDefaultSubscriptions(currency.BTC) }, - Connector: g.WsDeliveryFuturesConnect, - Enabled: func() bool { return true }, - }) - if err != nil { - return err - } + // // Futures connection - BTC margined + // err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + // URL: futuresWebsocketBtcURL, + // RateLimit: gateioWebsocketRateLimit, + // ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + // ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + // Handler: func(ctx context.Context, incoming []byte) error { + // return g.WsHandleFuturesData(ctx, incoming, asset.Futures) + // }, + // Subscriber: g.FuturesSubscribe, + // Unsubscriber: g.FuturesUnsubscribe, + // GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, + // Connector: g.WsFuturesConnect, + // // Enabled: func() bool { return g.CheckWebsocketEnabled(asset.Futures) }, + // }) + // if err != nil { + // return err + // } - // TODO: Add BTC margined delivery futures. + // // Futures connection - Delivery - USDT margined + // err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + // URL: deliveryRealUSDTTradingURL, + // RateLimit: gateioWebsocketRateLimit, + // ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + // ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + // Handler: func(ctx context.Context, incoming []byte) error { + // return g.WsHandleFuturesData(ctx, incoming, asset.DeliveryFutures) + // }, + // Subscriber: g.DeliveryFuturesSubscribe, + // Unsubscriber: g.DeliveryFuturesUnsubscribe, + // GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateDeliveryFuturesDefaultSubscriptions(currency.BTC) }, + // Connector: g.WsDeliveryFuturesConnect, + // // Enabled: func() bool { return true }, + // }) + // if err != nil { + // return err + // } - // Futures connection - Options - return g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - URL: optionsWebsocketURL, - RateLimit: gateioWebsocketRateLimit, - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: g.WsHandleOptionsData, - Subscriber: g.OptionsSubscribe, - Unsubscriber: g.OptionsUnsubscribe, - GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, - Connector: g.WsOptionsConnect, - Enabled: func() bool { return true }, - }) -} + // // TODO: Add BTC margined delivery futures. + + // // Futures connection - Options + // return g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + // URL: optionsWebsocketURL, + // RateLimit: gateioWebsocketRateLimit, + // ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + // ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + // Handler: g.WsHandleOptionsData, + // Subscriber: g.OptionsSubscribe, + // Unsubscriber: g.OptionsUnsubscribe, + // GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, + // Connector: g.WsOptionsConnect, + // // Enabled: func() bool { return true }, + // }) -// CheckWebsocketEnabled checks if the websocket is enabled for an individual asset -func (g *Gateio) CheckWebsocketEnabled(a asset.Item) bool { - err := g.CurrencyPairs.IsAssetEnabled(asset.Futures) - return err == nil && g.AssetWebsocketSupport.IsAssetWebsocketSupported(a) + return nil } // UpdateTicker updates and returns the ticker for a currency pair diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index d69625ca7c2..060dbb46f96 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -47,7 +47,6 @@ type ConnectionSetup struct { Unsubscriber func(ctx context.Context, conn Connection, unsub subscription.List) error GenerateSubscriptions func() (subscription.List, error) Connector func(ctx context.Context, conn Connection) error - Enabled func() bool } // PingHandler container for ping handler settings diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 073c022f36d..2b0c765ded1 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/url" + "os" "slices" "time" @@ -289,7 +290,6 @@ func (w *Websocket) Connect() error { w.setState(connectingState) if w.connector != nil { - fmt.Println("OLD CONNECTOR") err := w.connector() if err != nil { w.setState(disconnectedState) @@ -324,14 +324,16 @@ func (w *Websocket) Connect() error { for i := range w.PendingConnections { fmt.Println("SPAWN CONNECTION: ", i) - if !w.PendingConnections[i].Enabled() { - fmt.Println("Connection not enabled") - continue - } + // if !w.PendingConnections[i].Enabled() { + // fmt.Println("Connection not enabled") + // continue + // } subs, err := w.PendingConnections[i].GenerateSubscriptions() // regenerate state on new connection if err != nil { if errors.Is(err, asset.ErrNotEnabled) { - log.Warnf(log.WebsocketMgr, "%s websocket: %v", w.exchangeName, err) + if w.verbose { + log.Warnf(log.WebsocketMgr, "%s websocket: %v", w.exchangeName, err) + } continue // Non-fatal error, we can continue to the next connection } w.setState(disconnectedState) @@ -340,6 +342,12 @@ func (w *Websocket) Connect() error { fmt.Println("subs: ", len(subs)) + for x := range subs { + fmt.Println("SUBS: ", subs[x]) + } + + os.Exit(1) + if len(subs) == 0 { // If no subscriptions are generated, we skip the connection log.Warnf(log.WebsocketMgr, "%s websocket: no subscriptions generated", w.exchangeName) From feed04e10003ece5e74de2ccbe23b732afe536e7 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Mon, 15 Jul 2024 12:16:30 +1000 Subject: [PATCH 003/138] Add tests and shenanigans --- exchanges/gateio/gateio_test.go | 2 +- exchanges/gateio/gateio_websocket.go | 2 + exchanges/gateio/gateio_wrapper.go | 154 ++++++++---------- .../gateio/gateio_ws_delivery_futures.go | 24 +-- exchanges/gateio/gateio_ws_futures.go | 1 + exchanges/stream/websocket.go | 56 ++++--- exchanges/stream/websocket_connection.go | 3 + exchanges/stream/websocket_test.go | 81 ++++++--- 8 files changed, 174 insertions(+), 149 deletions(-) diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index c6df7abbdf5..4fa75de145b 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -2971,7 +2971,7 @@ func TestGenerateDefaultSubscriptionsSpot(t *testing.T) { } func TestGenerateDeliveryFuturesDefaultSubscriptions(t *testing.T) { t.Parallel() - if _, err := g.GenerateDeliveryFuturesDefaultSubscriptions(currency.USDT); err != nil { + if _, err := g.GenerateDeliveryFuturesDefaultSubscriptions(); err != nil { t.Error(err) } } diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index cefd40931b4..b68b85ee60d 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -635,6 +635,8 @@ func (g *Gateio) GenerateDefaultSubscriptionsSpot() (subscription.List, error) { assetType = asset.CrossMargin pairs, err = g.GetEnabledPairs(asset.CrossMargin) default: + // TODO: Check and add balance support as spot balances can be + // subscribed without a currency pair supplied. assetType = asset.Spot pairs, err = g.GetEnabledPairs(asset.Spot) } diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 8c5b1c7a905..030a3b3d466 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -27,6 +27,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" "github.com/thrasher-corp/gocryptotrader/log" @@ -151,6 +152,7 @@ func (g *Gateio) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } + // TODO: Add websocket margin and cross margin support. err = g.DisableAssetWebsocketSupport(asset.Margin) if err != nil { log.Errorln(log.ExchangeSys, err) @@ -159,18 +161,6 @@ func (g *Gateio) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - // err = g.DisableAssetWebsocketSupport(asset.Futures) - // if err != nil { - // log.Errorln(log.ExchangeSys, err) - // } - // err = g.DisableAssetWebsocketSupport(asset.DeliveryFutures) - // if err != nil { - // log.Errorln(log.ExchangeSys, err) - // } - // err = g.DisableAssetWebsocketSupport(asset.Options) - // if err != nil { - // log.Errorln(log.ExchangeSys, err) - // } g.API.Endpoints = g.NewEndpoints() err = g.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: gateioTradeURL, @@ -229,85 +219,77 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Unsubscriber: g.SpotUnsubscribe, GenerateSubscriptions: g.GenerateDefaultSubscriptionsSpot, Connector: g.WsConnectSpot, - // Enabled: func() bool { return g.CheckWebsocketEnabled(asset.Spot) }, }) if err != nil { return err } - // // Futures connection - USDT margined - // err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - // URL: futuresWebsocketUsdtURL, - // RateLimit: gateioWebsocketRateLimit, - // ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - // ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - // Handler: func(ctx context.Context, incoming []byte) error { - // return g.WsHandleFuturesData(ctx, incoming, asset.Futures) - // }, - // Subscriber: g.FuturesSubscribe, - // Unsubscriber: g.FuturesUnsubscribe, - // GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, - // Connector: g.WsFuturesConnect, - // // Enabled: func() bool { return g.CheckWebsocketEnabled(asset.Futures) }, - // }) - // if err != nil { - // return err - // } - - // // Futures connection - BTC margined - // err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - // URL: futuresWebsocketBtcURL, - // RateLimit: gateioWebsocketRateLimit, - // ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - // ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - // Handler: func(ctx context.Context, incoming []byte) error { - // return g.WsHandleFuturesData(ctx, incoming, asset.Futures) - // }, - // Subscriber: g.FuturesSubscribe, - // Unsubscriber: g.FuturesUnsubscribe, - // GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, - // Connector: g.WsFuturesConnect, - // // Enabled: func() bool { return g.CheckWebsocketEnabled(asset.Futures) }, - // }) - // if err != nil { - // return err - // } - - // // Futures connection - Delivery - USDT margined - // err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - // URL: deliveryRealUSDTTradingURL, - // RateLimit: gateioWebsocketRateLimit, - // ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - // ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - // Handler: func(ctx context.Context, incoming []byte) error { - // return g.WsHandleFuturesData(ctx, incoming, asset.DeliveryFutures) - // }, - // Subscriber: g.DeliveryFuturesSubscribe, - // Unsubscriber: g.DeliveryFuturesUnsubscribe, - // GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateDeliveryFuturesDefaultSubscriptions(currency.BTC) }, - // Connector: g.WsDeliveryFuturesConnect, - // // Enabled: func() bool { return true }, - // }) - // if err != nil { - // return err - // } - - // // TODO: Add BTC margined delivery futures. - - // // Futures connection - Options - // return g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - // URL: optionsWebsocketURL, - // RateLimit: gateioWebsocketRateLimit, - // ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - // ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - // Handler: g.WsHandleOptionsData, - // Subscriber: g.OptionsSubscribe, - // Unsubscriber: g.OptionsUnsubscribe, - // GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, - // Connector: g.WsOptionsConnect, - // // Enabled: func() bool { return true }, - // }) + // Futures connection - USDT margined + err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + URL: futuresWebsocketUsdtURL, + RateLimit: gateioWebsocketRateLimit, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: func(ctx context.Context, incoming []byte) error { + return g.WsHandleFuturesData(ctx, incoming, asset.Futures) + }, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, + Connector: g.WsFuturesConnect, + }) + if err != nil { + return err + } - return nil + // Futures connection - BTC margined + err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + URL: futuresWebsocketBtcURL, + RateLimit: gateioWebsocketRateLimit, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: func(ctx context.Context, incoming []byte) error { + return g.WsHandleFuturesData(ctx, incoming, asset.Futures) + }, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, + Connector: g.WsFuturesConnect, + }) + if err != nil { + return err + } + + // TODO: Add BTC margined delivery futures. + // Futures connection - Delivery - USDT margined + err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + URL: deliveryRealUSDTTradingURL, + RateLimit: gateioWebsocketRateLimit, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: func(ctx context.Context, incoming []byte) error { + return g.WsHandleFuturesData(ctx, incoming, asset.DeliveryFutures) + }, + Subscriber: g.DeliveryFuturesSubscribe, + Unsubscriber: g.DeliveryFuturesUnsubscribe, + GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions, + Connector: g.WsDeliveryFuturesConnect, + }) + if err != nil { + return err + } + + // Futures connection - Options + return g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + URL: optionsWebsocketURL, + RateLimit: gateioWebsocketRateLimit, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleOptionsData, + Subscriber: g.OptionsSubscribe, + Unsubscriber: g.OptionsUnsubscribe, + GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, + Connector: g.WsOptionsConnect, + }) } // UpdateTicker updates and returns the ticker for a currency pair diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index 5bb71b20dd5..1d0ed495828 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -71,7 +71,7 @@ func (g *Gateio) WsDeliveryFuturesConnect(ctx context.Context, conn stream.Conne } // GenerateDeliveryFuturesDefaultSubscriptions returns delivery futures default subscriptions params. -func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions(_ currency.Code) (subscription.List, error) { +func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.List, error) { _, err := g.GetCredentials(context.Background()) if err != nil { g.Websocket.SetCanUseAuthenticatedEndpoints(false) @@ -85,31 +85,11 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions(_ currency.Code) (s futuresBalancesChannel, ) } - pairs, err := g.GetAvailablePairs(asset.DeliveryFutures) + pairs, err := g.GetEnabledPairs(asset.DeliveryFutures) if err != nil { return nil, err } - // switch { - // case settlement.Equal(currency.USDT): - // pairs, err = pairs.GetPairsByQuote(currency.USDT) - // if err != nil { - // return nil, err - // } - // case settlement.Equal(currency.BTC): - // offset := 0 - // for x := range pairs { - // if pairs[x].Quote.Equal(currency.USDT) { - // continue // skip USDT pairs - // } - // pairs[offset] = pairs[x] - // offset++ - // } - // pairs = pairs[:offset] - // default: - // return nil, fmt.Errorf("settlement currency %s not supported", settlement) - // } - var subscriptions subscription.List for i := range channelsToSubscribe { for j := range pairs { diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index b1e10de5a81..0c0f9f11ad8 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -169,6 +169,7 @@ func (g *Gateio) FuturesUnsubscribe(ctx context.Context, conn stream.Connection, // WsHandleFuturesData handles futures websocket data func (g *Gateio) WsHandleFuturesData(ctx context.Context, respRaw []byte, a asset.Item) error { + fmt.Printf("Gateio WsHandleFuturesData: %s\n", string(respRaw)) var push WsResponse err := json.Unmarshal(respRaw, &push) if err != nil { diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 2b0c765ded1..c5c4e96b1e2 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "net/url" - "os" "slices" "time" @@ -316,18 +315,22 @@ func (w *Websocket) Connect() error { return nil } - fmt.Println("NEW CONNECTOR") + // hasStableConnection is used to determine if the websocket has a stable + // connection. If it does not, the websocket will be set to disconnected. + hasStableConnection := false + defer w.setStateFromHasStableConnection(&hasStableConnection) if len(w.PendingConnections) == 0 { return fmt.Errorf("cannot connect: %w", errNoPendingConnections) } + // TODO: Implement concurrency below. This can be achieved once there is + // more mutex protection around the subscriptions. for i := range w.PendingConnections { - fmt.Println("SPAWN CONNECTION: ", i) - // if !w.PendingConnections[i].Enabled() { - // fmt.Println("Connection not enabled") - // continue - // } + if w.PendingConnections[i].GenerateSubscriptions == nil { + return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.PendingConnections[i].URL, errWebsocketSubscriptionsGeneratorUnset) + } + subs, err := w.PendingConnections[i].GenerateSubscriptions() // regenerate state on new connection if err != nil { if errors.Is(err, asset.ErrNotEnabled) { @@ -336,42 +339,45 @@ func (w *Websocket) Connect() error { } continue // Non-fatal error, we can continue to the next connection } - w.setState(disconnectedState) return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) } - fmt.Println("subs: ", len(subs)) - - for x := range subs { - fmt.Println("SUBS: ", subs[x]) - } - - os.Exit(1) - if len(subs) == 0 { // If no subscriptions are generated, we skip the connection - log.Warnf(log.WebsocketMgr, "%s websocket: no subscriptions generated", w.exchangeName) + if w.verbose { + log.Warnf(log.WebsocketMgr, "%s websocket: no subscriptions generated", w.exchangeName) + } continue } - // TODO: Add window to max subscriptions per connection, to spawn new connections if needed. + if w.PendingConnections[i].Connector == nil { + return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.PendingConnections[i].URL, errNoConnectFunc) + } + if w.PendingConnections[i].Handler == nil { + return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.PendingConnections[i].URL, errWebsocketDataHandlerUnset) + } + if w.PendingConnections[i].Subscriber == nil { + return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.PendingConnections[i].URL, errWebsocketSubscriberUnset) + } + + // TODO: Add window for max subscriptions per connection, to spawn new connections if needed. conn := w.getConnectionFromSetup(w.PendingConnections[i]) err = w.PendingConnections[i].Connector(context.TODO(), conn) if err != nil { return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) } + + hasStableConnection = true + w.Wg.Add(1) go w.Reader(context.TODO(), conn, w.PendingConnections[i].Handler) - fmt.Println("Subscribing to channels: ", len(subs)) err = w.PendingConnections[i].Subscriber(context.TODO(), conn, subs) if err != nil { return fmt.Errorf("%v Error subscribing %w", w.exchangeName, err) } } - fmt.Println("DONE SPAWNING CONNECTIONS") - if !w.IsConnectionMonitorRunning() { err := w.connectionMonitor() if err != nil { @@ -382,6 +388,14 @@ func (w *Websocket) Connect() error { return nil } +func (w *Websocket) setStateFromHasStableConnection(hasStableConnection *bool) { + if *hasStableConnection { + w.setState(connectedState) + } else { + w.setState(disconnectedState) + } +} + // Disable disables the exchange websocket protocol // Note that connectionMonitor will be responsible for shutting down the websocket after disabling func (w *Websocket) Disable() error { diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index ebc5d340d03..2871b1d1d23 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -224,6 +224,9 @@ func (w *WebsocketConnection) IsConnected() bool { // ReadMessage reads messages, can handle text, gzip and binary func (w *WebsocketConnection) ReadMessage() Response { + if w.Connection == nil { + return Response{} + } mType, resp, err := w.Connection.ReadMessage() if err != nil { if IsDisconnectionError(err) { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 6e7fa1b888f..c79671a3762 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -4,6 +4,7 @@ import ( "bytes" "compress/flate" "compress/gzip" + "context" "encoding/json" "errors" "fmt" @@ -140,21 +141,6 @@ func TestSetup(t *testing.T) { assert.ErrorIs(t, err, errConfigFeaturesIsNil, "Setup should error correctly") websocketSetup.ExchangeConfig.Features = &config.FeaturesConfig{} - err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errWebsocketConnectorUnset, "Setup should error correctly") - - websocketSetup.Connector = func() error { return nil } - err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errWebsocketSubscriberUnset, "Setup should error correctly") - - websocketSetup.Subscriber = func(subscription.List) error { return nil } - websocketSetup.Features.Unsubscribe = true - err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errWebsocketUnsubscriberUnset, "Setup should error correctly") - - websocketSetup.Unsubscriber = func(subscription.List) error { return nil } - err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset, "Setup should error correctly") websocketSetup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } err = w.Setup(websocketSetup) @@ -296,11 +282,8 @@ func TestIsDisconnectionError(t *testing.T) { func TestConnectionMessageErrors(t *testing.T) { t.Parallel() var wsWrong = &Websocket{} - err := wsWrong.Connect() - assert.ErrorIs(t, err, errNoConnectFunc, "Connect should error correctly") - wsWrong.connector = func() error { return nil } - err = wsWrong.Connect() + err := wsWrong.Connect() assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly") wsWrong.setEnabled(true) @@ -351,6 +334,66 @@ func TestConnectionMessageErrors(t *testing.T) { ws.ReadMessageErrors <- &websocket.CloseError{Code: 1006, Text: "SpecialText"} assert.EventuallyWithT(t, c, 2*time.Second, 10*time.Millisecond, "Should get an error down the routine") + + // Test individual connection defined functions + ws.connector = nil + + err = ws.Connect() + assert.ErrorIs(t, err, errNoPendingConnections, "Connect should error correctly") + + ws.PendingConnections = []ConnectionSetup{{URL: "ws://localhost:8080/ws"}} + err = ws.Connect() + require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) + + ws.PendingConnections[0].GenerateSubscriptions = func() (subscription.List, error) { + return nil, errDastardlyReason + } + err = ws.Connect() + require.ErrorIs(t, err, errDastardlyReason) + + ws.PendingConnections[0].GenerateSubscriptions = func() (subscription.List, error) { + return subscription.List{{}}, nil + } + err = ws.Connect() + require.ErrorIs(t, err, errNoConnectFunc) + + ws.PendingConnections[0].Connector = func(context.Context, Connection) error { + return errDastardlyReason + } + err = ws.Connect() + require.ErrorIs(t, err, errWebsocketDataHandlerUnset) + + ws.PendingConnections[0].Handler = func(context.Context, []byte) error { + return errDastardlyReason + } + err = ws.Connect() + require.ErrorIs(t, err, errWebsocketSubscriberUnset) + + ws.PendingConnections[0].Subscriber = func(context.Context, Connection, subscription.List) error { + return errDastardlyReason + } + err = ws.Connect() + require.ErrorIs(t, err, errDastardlyReason) + + ws.PendingConnections[0].Connector = func(ctx context.Context, conn Connection) error { + return nil + } + err = ws.Connect() + require.ErrorIs(t, err, errDastardlyReason) + + ws.PendingConnections[0].Handler = func(context.Context, []byte) error { + return nil + } + require.NoError(t, ws.Shutdown()) + err = ws.Connect() + require.ErrorIs(t, err, errDastardlyReason) + + ws.PendingConnections[0].Subscriber = func(context.Context, Connection, subscription.List) error { + return nil + } + require.NoError(t, ws.Shutdown()) + err = ws.Connect() + require.NoError(t, err) } func TestWebsocket(t *testing.T) { From 31a26c074c9fbe7fa0738f54a65a0e58aa88ffb9 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Mon, 15 Jul 2024 15:29:10 +1000 Subject: [PATCH 004/138] integrate flushing and for enabling/disabling pairs from rpc shenanigans --- docs/ADD_NEW_EXCHANGE.md | 4 +- exchanges/binance/binance_websocket.go | 4 +- exchanges/binanceus/binanceus_websocket.go | 4 +- exchanges/bitfinex/bitfinex_websocket.go | 8 +- exchanges/bithumb/bithumb_websocket.go | 2 +- exchanges/bitmex/bitmex_websocket.go | 4 +- exchanges/bitstamp/bitstamp_websocket.go | 4 +- exchanges/btcmarkets/btcmarkets_websocket.go | 4 +- exchanges/btse/btse_websocket.go | 4 +- .../coinbasepro/coinbasepro_websocket.go | 4 +- exchanges/coinut/coinut_websocket.go | 4 +- exchanges/exchange.go | 4 +- exchanges/gateio/gateio_websocket.go | 4 +- exchanges/gateio/gateio_wrapper.go | 2 +- .../gateio/gateio_ws_delivery_futures.go | 81 ++--- exchanges/gateio/gateio_ws_futures.go | 82 +++-- exchanges/gateio/gateio_ws_option.go | 4 +- exchanges/gemini/gemini_websocket.go | 4 +- exchanges/hitbtc/hitbtc_websocket.go | 4 +- exchanges/huobi/huobi_websocket.go | 4 +- exchanges/kraken/kraken_websocket.go | 6 +- exchanges/kucoin/kucoin_websocket.go | 4 +- exchanges/okcoin/okcoin_websocket.go | 12 +- exchanges/okx/okx_websocket.go | 12 +- exchanges/poloniex/poloniex_websocket.go | 4 +- exchanges/stream/websocket.go | 281 ++++++++++++++---- exchanges/stream/websocket_test.go | 82 ++--- exchanges/stream/websocket_types.go | 13 +- 28 files changed, 406 insertions(+), 243 deletions(-) diff --git a/docs/ADD_NEW_EXCHANGE.md b/docs/ADD_NEW_EXCHANGE.md index d0d4d511391..834a8c3107b 100644 --- a/docs/ADD_NEW_EXCHANGE.md +++ b/docs/ADD_NEW_EXCHANGE.md @@ -837,7 +837,7 @@ channels: continue } // When we have a successful subscription, we can alert our internal management system of the success. - f.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[i]) + f.Websocket.AddSuccessfulSubscriptions(nil, channelsToSubscribe[i]) } return errs } @@ -1077,7 +1077,7 @@ channels: continue } // When we have a successful unsubscription, we can alert our internal management system of the success. - f.Websocket.RemoveSubscriptions(channelsToUnsubscribe[i]) + f.Websocket.RemoveSubscriptions(nil, channelsToUnsubscribe[i]) } if errs != nil { return errs diff --git a/exchanges/binance/binance_websocket.go b/exchanges/binance/binance_websocket.go index 1168dc664a3..ad4aecf9b4d 100644 --- a/exchanges/binance/binance_websocket.go +++ b/exchanges/binance/binance_websocket.go @@ -593,7 +593,7 @@ func (b *Binance) manageSubs(op string, subs subscription.List) error { b.Websocket.DataHandler <- err if op == wsSubscribeMethod { - if err2 := b.Websocket.RemoveSubscriptions(subs...); err2 != nil { + if err2 := b.Websocket.RemoveSubscriptions(nil, subs...); err2 != nil { err = common.AppendError(err, err2) } } @@ -601,7 +601,7 @@ func (b *Binance) manageSubs(op string, subs subscription.List) error { if op == wsSubscribeMethod { err = common.AppendError(err, subs.SetStates(subscription.SubscribedState)) } else { - err = b.Websocket.RemoveSubscriptions(subs...) + err = b.Websocket.RemoveSubscriptions(nil, subs...) } } diff --git a/exchanges/binanceus/binanceus_websocket.go b/exchanges/binanceus/binanceus_websocket.go index 47e48f167d4..6335bddcab6 100644 --- a/exchanges/binanceus/binanceus_websocket.go +++ b/exchanges/binanceus/binanceus_websocket.go @@ -590,7 +590,7 @@ func (bi *Binanceus) Subscribe(channelsToSubscribe subscription.List) error { return err } } - return bi.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe...) + return bi.Websocket.AddSuccessfulSubscriptions(nil, channelsToSubscribe...) } // Unsubscribe unsubscribes from a set of channels @@ -614,7 +614,7 @@ func (bi *Binanceus) Unsubscribe(channelsToUnsubscribe subscription.List) error return err } } - return bi.Websocket.RemoveSubscriptions(channelsToUnsubscribe...) + return bi.Websocket.RemoveSubscriptions(nil, channelsToUnsubscribe...) } func (bi *Binanceus) setupOrderbookManager() { diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index ed04da9988f..4fec8b42086 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -511,7 +511,7 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { c.Key = int(chanID) // subscribeToChan removes the old subID keyed Subscription - if err := b.Websocket.AddSuccessfulSubscriptions(c); err != nil { + if err := b.Websocket.AddSuccessfulSubscriptions(nil, c); err != nil { return fmt.Errorf("%w: %w subID: %s", stream.ErrSubscriptionFailure, err, subID) } @@ -1660,7 +1660,7 @@ func (b *Bitfinex) resubOrderbook(c *subscription.Subscription) error { // Resub will block so we have to do this in a goro go func() { - if err := b.Websocket.ResubscribeToChannel(c); err != nil { + if err := b.Websocket.ResubscribeToChannel(nil, c); err != nil { log.Errorf(log.ExchangeSys, "%s error resubscribing orderbook: %v", b.Name, err) } }() @@ -1753,7 +1753,7 @@ func (b *Bitfinex) subscribeToChan(chans subscription.List) error { // Always remove the temporary subscription keyed by subID defer func() { - _ = b.Websocket.RemoveSubscriptions(c) + _ = b.Websocket.RemoveSubscriptions(nil, c) }() respRaw, err := b.Websocket.Conn.SendMessageReturnResponse("subscribe:"+subID, req) @@ -1860,7 +1860,7 @@ func (b *Bitfinex) unsubscribeFromChan(chans subscription.List) error { return wErr } - return b.Websocket.RemoveSubscriptions(c) + return b.Websocket.RemoveSubscriptions(nil, c) } // getErrResp takes a json response string and looks for an error event type diff --git a/exchanges/bithumb/bithumb_websocket.go b/exchanges/bithumb/bithumb_websocket.go index 929d68a61e2..01c61cdeaff 100644 --- a/exchanges/bithumb/bithumb_websocket.go +++ b/exchanges/bithumb/bithumb_websocket.go @@ -205,7 +205,7 @@ func (b *Bithumb) Subscribe(channelsToSubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(req) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(s) + err = b.Websocket.AddSuccessfulSubscriptions(nil, s) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index 6779a97a1e9..dce46c9d4e0 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -601,7 +601,7 @@ func (b *Bitmex) Subscribe(subs subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(req) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(subs...) + err = b.Websocket.AddSuccessfulSubscriptions(nil, subs...) } return err } @@ -620,7 +620,7 @@ func (b *Bitmex) Unsubscribe(subs subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(req) if err == nil { - err = b.Websocket.RemoveSubscriptions(subs...) + err = b.Websocket.RemoveSubscriptions(nil, subs...) } return err } diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index ab887e422a3..275090be2d3 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -294,7 +294,7 @@ func (b *Bitstamp) Subscribe(channelsToSubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(req) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(s) + err = b.Websocket.AddSuccessfulSubscriptions(nil, s) } if err != nil { errs = common.AppendError(errs, err) @@ -316,7 +316,7 @@ func (b *Bitstamp) Unsubscribe(channelsToUnsubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(req) if err == nil { - err = b.Websocket.RemoveSubscriptions(s) + err = b.Websocket.RemoveSubscriptions(nil, s) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/btcmarkets/btcmarkets_websocket.go b/exchanges/btcmarkets/btcmarkets_websocket.go index 18a715ae848..4c817a3bc25 100644 --- a/exchanges/btcmarkets/btcmarkets_websocket.go +++ b/exchanges/btcmarkets/btcmarkets_websocket.go @@ -376,7 +376,7 @@ func (b *BTCMarkets) Subscribe(subs subscription.List) error { err := b.Websocket.Conn.SendJSONMessage(r) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(s) + err = b.Websocket.AddSuccessfulSubscriptions(nil, s) } if err != nil { errs = common.AppendError(errs, err) @@ -416,7 +416,7 @@ func (b *BTCMarkets) Unsubscribe(subs subscription.List) error { err := b.Websocket.Conn.SendJSONMessage(req) if err == nil { - err = b.Websocket.RemoveSubscriptions(s) + err = b.Websocket.RemoveSubscriptions(nil, s) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/btse/btse_websocket.go b/exchanges/btse/btse_websocket.go index 4bac49517bb..f5af037b10c 100644 --- a/exchanges/btse/btse_websocket.go +++ b/exchanges/btse/btse_websocket.go @@ -394,7 +394,7 @@ func (b *BTSE) Subscribe(channelsToSubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(sub) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe...) + err = b.Websocket.AddSuccessfulSubscriptions(nil, channelsToSubscribe...) } return err } @@ -409,7 +409,7 @@ func (b *BTSE) Unsubscribe(channelsToUnsubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(unSub) if err == nil { - err = b.Websocket.RemoveSubscriptions(channelsToUnsubscribe...) + err = b.Websocket.RemoveSubscriptions(nil, channelsToUnsubscribe...) } return err } diff --git a/exchanges/coinbasepro/coinbasepro_websocket.go b/exchanges/coinbasepro/coinbasepro_websocket.go index 4765eba59b5..2476be67db2 100644 --- a/exchanges/coinbasepro/coinbasepro_websocket.go +++ b/exchanges/coinbasepro/coinbasepro_websocket.go @@ -425,7 +425,7 @@ func (c *CoinbasePro) Subscribe(subs subscription.List) error { } err := c.Websocket.Conn.SendJSONMessage(r) if err == nil { - err = c.Websocket.AddSuccessfulSubscriptions(subs...) + err = c.Websocket.AddSuccessfulSubscriptions(nil, subs...) } return err } @@ -461,7 +461,7 @@ func (c *CoinbasePro) Unsubscribe(subs subscription.List) error { } err := c.Websocket.Conn.SendJSONMessage(r) if err == nil { - err = c.Websocket.RemoveSubscriptions(subs...) + err = c.Websocket.RemoveSubscriptions(nil, subs...) } return err } diff --git a/exchanges/coinut/coinut_websocket.go b/exchanges/coinut/coinut_websocket.go index b87792aede2..7f1bb3e139f 100644 --- a/exchanges/coinut/coinut_websocket.go +++ b/exchanges/coinut/coinut_websocket.go @@ -620,7 +620,7 @@ func (c *COINUT) Subscribe(subs subscription.List) error { } err = c.Websocket.Conn.SendJSONMessage(subscribe) if err == nil { - err = c.Websocket.AddSuccessfulSubscriptions(s) + err = c.Websocket.AddSuccessfulSubscriptions(nil, s) } if err != nil { errs = common.AppendError(errs, err) @@ -663,7 +663,7 @@ func (c *COINUT) Unsubscribe(channelToUnsubscribe subscription.List) error { case len(val) == 0, val[0] != "OK": err = common.AppendError(errs, fmt.Errorf("%v unsubscribe failed for channel %v", c.Name, s.Channel)) default: - err = c.Websocket.RemoveSubscriptions(s) + err = c.Websocket.RemoveSubscriptions(nil, s) } } if err != nil { diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 318f2d0c84e..4876ea8a989 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -1131,7 +1131,7 @@ func (b *Base) SubscribeToWebsocketChannels(channels subscription.List) error { if b.Websocket == nil { return common.ErrFunctionNotSupported } - return b.Websocket.SubscribeToChannels(channels) + return b.Websocket.SubscribeToChannels(nil, channels) } // UnsubscribeToWebsocketChannels removes from ChannelsToSubscribe @@ -1140,7 +1140,7 @@ func (b *Base) UnsubscribeToWebsocketChannels(channels subscription.List) error if b.Websocket == nil { return common.ErrFunctionNotSupported } - return b.Websocket.UnsubscribeChannels(channels) + return b.Websocket.UnsubscribeChannels(nil, channels) } // GetSubscriptions returns a copied list of subscriptions diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index b68b85ee60d..1ed473f5dd2 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -697,9 +697,9 @@ func (g *Gateio) handleSubscription(ctx context.Context, conn stream.Connection, continue } if payloads[k].Event == "subscribe" { - err = g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k]) + err = g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[k]) } else { - err = g.Websocket.RemoveSubscriptions(channelsToSubscribe[k]) + err = g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[k]) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 030a3b3d466..cdeeb96ff90 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -91,12 +91,12 @@ func (g *Gateio) SetDefaults() { OrderbookFetching: true, TradeFetching: true, KlineFetching: true, - FullPayloadSubscribe: true, AuthenticatedEndpoints: true, MessageCorrelation: true, GetOrder: true, AccountBalance: true, Subscribe: true, + Unsubscribe: true, }, WithdrawPermissions: exchange.AutoWithdrawCrypto | exchange.NoFiatWithdrawals, diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index 1d0ed495828..6d0dfb71248 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "strconv" - "strings" "time" "github.com/gorilla/websocket" @@ -133,40 +132,36 @@ func (g *Gateio) handleDeliveryFuturesSubscription(ctx context.Context, conn str } var errs error var respByte []byte - // con represents the websocket connection. 0 - for usdt settle and 1 - for btc settle connections. - for con, val := range payloads { - for k := range val { - if con == 0 { - respByte, err = conn.SendMessageReturnResponse(val[k].ID, val[k]) + for i, val := range payloads { + respByte, err = conn.SendMessageReturnResponse(val.ID, val) + if err != nil { + errs = common.AppendError(errs, err) + continue + } + var resp WsEventResponse + if err = json.Unmarshal(respByte, &resp); err != nil { + errs = common.AppendError(errs, err) + } else { + if resp.Error != nil && resp.Error.Code != 0 { + errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", val.Event, val.Channel, resp.Error.Code, resp.Error.Message)) + continue + } + if val.Event == "subscribe" { + err = g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[i]) } else { - // TODO: Split into two. - respByte, err = conn.SendMessageReturnResponse(val[k].ID, val[k]) + err = g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[i]) } if err != nil { errs = common.AppendError(errs, err) - continue - } - var resp WsEventResponse - if err = json.Unmarshal(respByte, &resp); err != nil { - errs = common.AppendError(errs, err) - } else { - if resp.Error != nil && resp.Error.Code != 0 { - errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", val[k].Event, val[k].Channel, resp.Error.Code, resp.Error.Message)) - continue - } - if err = g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k]); err != nil { - errs = common.AppendError(errs, err) - } } } } return errs } -func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([2][]WsInput, error) { - payloads := [2][]WsInput{} +func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { - return payloads, errors.New("cannot generate payload, no channels supplied") + return nil, errors.New("cannot generate payload, no channels supplied") } var creds *account.Credentials var err error @@ -176,9 +171,10 @@ func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn stream g.Websocket.SetCanUseAuthenticatedEndpoints(false) } } + var outbound []WsInput for i := range channelsToSubscribe { if len(channelsToSubscribe[i].Pairs) != 1 { - return payloads, subscription.ErrNotSinglePair + return nil, subscription.ErrNotSinglePair } var auth *WsAuthInput timestamp := time.Now() @@ -198,7 +194,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn stream var sigTemp string sigTemp, err = g.generateWsSignature(creds.Secret, event, channelsToSubscribe[i].Channel, timestamp) if err != nil { - return [2][]WsInput{}, err + return nil, err } auth = &WsAuthInput{ Method: "api_key", @@ -212,7 +208,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn stream var frequencyString string frequencyString, err = g.GetIntervalString(frequency) if err != nil { - return payloads, err + return nil, err } params = append(params, frequencyString) } @@ -235,7 +231,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn stream var intervalString string intervalString, err = g.GetIntervalString(interval) if err != nil { - return payloads, err + return nil, err } params = append([]string{intervalString}, params...) } @@ -245,25 +241,14 @@ func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn stream params = append(params, intervalString) } } - if strings.HasPrefix(channelsToSubscribe[i].Pairs[0].Quote.Upper().String(), "USDT") { - payloads[0] = append(payloads[0], WsInput{ - ID: conn.GenerateMessageID(false), - Event: event, - Channel: channelsToSubscribe[i].Channel, - Payload: params, - Auth: auth, - Time: timestamp.Unix(), - }) - } else { - payloads[1] = append(payloads[1], WsInput{ - ID: conn.GenerateMessageID(false), - Event: event, - Channel: channelsToSubscribe[i].Channel, - Payload: params, - Auth: auth, - Time: timestamp.Unix(), - }) - } + outbound = append(outbound, WsInput{ + ID: conn.GenerateMessageID(false), + Event: event, + Channel: channelsToSubscribe[i].Channel, + Payload: params, + Auth: auth, + Time: timestamp.Unix(), + }) } - return payloads, nil + return outbound, nil } diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 0c0f9f11ad8..5c44093f283 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -169,7 +169,6 @@ func (g *Gateio) FuturesUnsubscribe(ctx context.Context, conn stream.Connection, // WsHandleFuturesData handles futures websocket data func (g *Gateio) WsHandleFuturesData(ctx context.Context, respRaw []byte, a asset.Item) error { - fmt.Printf("Gateio WsHandleFuturesData: %s\n", string(respRaw)) var push WsResponse err := json.Unmarshal(respRaw, &push) if err != nil { @@ -236,30 +235,27 @@ func (g *Gateio) handleFuturesSubscription(ctx context.Context, conn stream.Conn } var errs error var respByte []byte - // con represents the websocket connection. 0 - for usdt settle and 1 - for btc settle connections. - for con, val := range payloads { - for k := range val { - if con == 0 { - respByte, err = conn.SendMessageReturnResponse(val[k].ID, val[k]) + for i, val := range payloads { + respByte, err = conn.SendMessageReturnResponse(val.ID, val) + if err != nil { + errs = common.AppendError(errs, err) + continue + } + var resp WsEventResponse + if err = json.Unmarshal(respByte, &resp); err != nil { + errs = common.AppendError(errs, err) + } else { + if resp.Error != nil && resp.Error.Code != 0 { + errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", val.Event, val.Channel, resp.Error.Code, resp.Error.Message)) + continue + } + if val.Event == "subscribe" { + err = g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[i]) } else { - // TODO: Authconn test temp - respByte, err = conn.SendMessageReturnResponse(val[k].ID, val[k]) + err = g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[i]) } if err != nil { errs = common.AppendError(errs, err) - continue - } - var resp WsEventResponse - if err = json.Unmarshal(respByte, &resp); err != nil { - errs = common.AppendError(errs, err) - } else { - if resp.Error != nil && resp.Error.Code != 0 { - errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", val[k].Event, val[k].Channel, resp.Error.Code, resp.Error.Message)) - continue - } - if err = g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k]); err != nil { - errs = common.AppendError(errs, err) - } } } } @@ -269,10 +265,9 @@ func (g *Gateio) handleFuturesSubscription(ctx context.Context, conn stream.Conn return nil } -func (g *Gateio) generateFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([2][]WsInput, error) { - payloads := [2][]WsInput{} +func (g *Gateio) generateFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { - return payloads, errors.New("cannot generate payload, no channels supplied") + return nil, errors.New("cannot generate payload, no channels supplied") } var creds *account.Credentials var err error @@ -282,9 +277,11 @@ func (g *Gateio) generateFuturesPayload(ctx context.Context, conn stream.Connect g.Websocket.SetCanUseAuthenticatedEndpoints(false) } } + + var outbound []WsInput for i := range channelsToSubscribe { if len(channelsToSubscribe[i].Pairs) != 1 { - return payloads, subscription.ErrNotSinglePair + return nil, subscription.ErrNotSinglePair } var auth *WsAuthInput timestamp := time.Now() @@ -306,7 +303,7 @@ func (g *Gateio) generateFuturesPayload(ctx context.Context, conn stream.Connect var sigTemp string sigTemp, err = g.generateWsSignature(creds.Secret, event, channelsToSubscribe[i].Channel, timestamp) if err != nil { - return [2][]WsInput{}, err + return nil, err } auth = &WsAuthInput{ Method: "api_key", @@ -320,7 +317,7 @@ func (g *Gateio) generateFuturesPayload(ctx context.Context, conn stream.Connect var frequencyString string frequencyString, err = g.GetIntervalString(frequency) if err != nil { - return payloads, err + return nil, err } params = append(params, frequencyString) } @@ -343,7 +340,7 @@ func (g *Gateio) generateFuturesPayload(ctx context.Context, conn stream.Connect var intervalString string intervalString, err = g.GetIntervalString(interval) if err != nil { - return payloads, err + return nil, err } params = append([]string{intervalString}, params...) } @@ -353,27 +350,16 @@ func (g *Gateio) generateFuturesPayload(ctx context.Context, conn stream.Connect params = append(params, intervalString) } } - if strings.HasPrefix(channelsToSubscribe[i].Pairs[0].Quote.Upper().String(), "USDT") { - payloads[0] = append(payloads[0], WsInput{ - ID: conn.GenerateMessageID(false), - Event: event, - Channel: channelsToSubscribe[i].Channel, - Payload: params, - Auth: auth, - Time: timestamp.Unix(), - }) - } else { - payloads[1] = append(payloads[1], WsInput{ - ID: conn.GenerateMessageID(false), - Event: event, - Channel: channelsToSubscribe[i].Channel, - Payload: params, - Auth: auth, - Time: timestamp.Unix(), - }) - } + outbound = append(outbound, WsInput{ + ID: conn.GenerateMessageID(false), + Event: event, + Channel: channelsToSubscribe[i].Channel, + Payload: params, + Auth: auth, + Time: timestamp.Unix(), + }) } - return payloads, nil + return outbound, nil } func (g *Gateio) processFuturesTickers(data []byte, assetType asset.Item) error { diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index 949b98ce2f4..51a396b8427 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -316,9 +316,9 @@ func (g *Gateio) handleOptionsSubscription(ctx context.Context, conn stream.Conn continue } if payloads[k].Event == "subscribe" { - err = g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k]) + err = g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[k]) } else { - err = g.Websocket.RemoveSubscriptions(channelsToSubscribe[k]) + err = g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[k]) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/gemini/gemini_websocket.go b/exchanges/gemini/gemini_websocket.go index 566611bce1a..dfb24b2cadf 100644 --- a/exchanges/gemini/gemini_websocket.go +++ b/exchanges/gemini/gemini_websocket.go @@ -117,10 +117,10 @@ func (g *Gemini) manageSubs(subs subscription.List, op wsSubOp) error { } if op == wsUnsubscribeOp { - return g.Websocket.RemoveSubscriptions(subs...) + return g.Websocket.RemoveSubscriptions(nil, subs...) } - return g.Websocket.AddSuccessfulSubscriptions(subs...) + return g.Websocket.AddSuccessfulSubscriptions(nil, subs...) } // WsAuth will connect to Gemini's secure endpoint diff --git a/exchanges/hitbtc/hitbtc_websocket.go b/exchanges/hitbtc/hitbtc_websocket.go index c361daaa2f5..38d0f6455bb 100644 --- a/exchanges/hitbtc/hitbtc_websocket.go +++ b/exchanges/hitbtc/hitbtc_websocket.go @@ -526,7 +526,7 @@ func (h *HitBTC) Subscribe(channelsToSubscribe subscription.List) error { err := h.Websocket.Conn.SendJSONMessage(r) if err == nil { - err = h.Websocket.AddSuccessfulSubscriptions(s) + err = h.Websocket.AddSuccessfulSubscriptions(nil, s) } if err != nil { errs = common.AppendError(errs, err) @@ -562,7 +562,7 @@ func (h *HitBTC) Unsubscribe(subs subscription.List) error { err := h.Websocket.Conn.SendJSONMessage(r) if err == nil { - err = h.Websocket.RemoveSubscriptions(s) + err = h.Websocket.RemoveSubscriptions(nil, s) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index 4a4e4d7adf7..6ae741854e4 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -570,7 +570,7 @@ func (h *HUOBI) Subscribe(channelsToSubscribe subscription.List) error { }) } if err == nil { - err = h.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[i]) + err = h.Websocket.AddSuccessfulSubscriptions(nil, channelsToSubscribe[i]) } if err != nil { errs = common.AppendError(errs, err) @@ -604,7 +604,7 @@ func (h *HUOBI) Unsubscribe(channelsToUnsubscribe subscription.List) error { }) } if err == nil { - err = h.Websocket.RemoveSubscriptions(channelsToUnsubscribe[i]) + err = h.Websocket.RemoveSubscriptions(nil, channelsToUnsubscribe[i]) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index eedccc8ff49..b1e41d31fb8 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -800,7 +800,7 @@ func (k *Kraken) wsProcessOrderBook(channelData *WebsocketChannelData, data map[ go func(resub *subscription.Subscription) { // This was locking the main websocket reader routine and a // backlog occurred. So put this into it's own go routine. - errResub := k.Websocket.ResubscribeToChannel(resub) + errResub := k.Websocket.ResubscribeToChannel(nil, resub) if errResub != nil { log.Errorf(log.WebsocketMgr, "resubscription failure for %v: %v", @@ -1235,7 +1235,7 @@ channels: _, err = k.Websocket.Conn.SendMessageReturnResponse((*subs)[i].RequestID, (*subs)[i]) } if err == nil { - err = k.Websocket.AddSuccessfulSubscriptions((*subs)[i].Channels...) + err = k.Websocket.AddSuccessfulSubscriptions(nil, (*subs)[i].Channels...) } if err != nil { errs = common.AppendError(errs, err) @@ -1294,7 +1294,7 @@ channels: _, err = k.Websocket.Conn.SendMessageReturnResponse(unsubs[i].RequestID, unsubs[i]) } if err == nil { - err = k.Websocket.RemoveSubscriptions(unsubs[i].Channels...) + err = k.Websocket.RemoveSubscriptions(nil, unsubs[i].Channels...) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index 5b7c14bde48..8feb1787b6f 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -1007,9 +1007,9 @@ func (ku *Kucoin) manageSubscriptions(subs subscription.List, operation string) errs = common.AppendError(errs, fmt.Errorf("%w: %s from %s", errInvalidMsgType, rType, respRaw)) default: if operation == "unsubscribe" { - err = ku.Websocket.RemoveSubscriptions(s) + err = ku.Websocket.RemoveSubscriptions(nil, s) } else { - err = ku.Websocket.AddSuccessfulSubscriptions(s) + err = ku.Websocket.AddSuccessfulSubscriptions(nil, s) if ku.Verbose { log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s", ku.Name, s.Channel) } diff --git a/exchanges/okcoin/okcoin_websocket.go b/exchanges/okcoin/okcoin_websocket.go index b5af0f29b83..30a39aa9c06 100644 --- a/exchanges/okcoin/okcoin_websocket.go +++ b/exchanges/okcoin/okcoin_websocket.go @@ -931,15 +931,15 @@ func (o *Okcoin) manageSubscriptions(operation string, subs subscription.List) e if operation == "unsubscribe" { if authenticatedChannelSubscription { - err = o.Websocket.RemoveSubscriptions(authChannels...) + err = o.Websocket.RemoveSubscriptions(nil, authChannels...) } else { - err = o.Websocket.RemoveSubscriptions(channels...) + err = o.Websocket.RemoveSubscriptions(nil, channels...) } } else { if authenticatedChannelSubscription { - err = o.Websocket.AddSuccessfulSubscriptions(authChannels...) + err = o.Websocket.AddSuccessfulSubscriptions(nil, authChannels...) } else { - err = o.Websocket.AddSuccessfulSubscriptions(channels...) + err = o.Websocket.AddSuccessfulSubscriptions(nil, channels...) } } if err != nil { @@ -974,9 +974,9 @@ func (o *Okcoin) manageSubscriptions(operation string, subs subscription.List) e } } if operation == "unsubscribe" { - return o.Websocket.RemoveSubscriptions(channels...) + return o.Websocket.RemoveSubscriptions(nil, channels...) } - return o.Websocket.AddSuccessfulSubscriptions(channels...) + return o.Websocket.AddSuccessfulSubscriptions(nil, channels...) } // GetCandlesData represents a candlestick instances list. diff --git a/exchanges/okx/okx_websocket.go b/exchanges/okx/okx_websocket.go index ede276a4c97..3ff3660c918 100644 --- a/exchanges/okx/okx_websocket.go +++ b/exchanges/okx/okx_websocket.go @@ -486,9 +486,9 @@ func (ok *Okx) handleSubscription(operation string, subscriptions subscription.L return err } if operation == operationUnsubscribe { - err = ok.Websocket.RemoveSubscriptions(channels...) + err = ok.Websocket.RemoveSubscriptions(nil, channels...) } else { - err = ok.Websocket.AddSuccessfulSubscriptions(channels...) + err = ok.Websocket.AddSuccessfulSubscriptions(nil, channels...) } if err != nil { return err @@ -510,9 +510,9 @@ func (ok *Okx) handleSubscription(operation string, subscriptions subscription.L return err } if operation == operationUnsubscribe { - err = ok.Websocket.RemoveSubscriptions(channels...) + err = ok.Websocket.RemoveSubscriptions(nil, channels...) } else { - err = ok.Websocket.AddSuccessfulSubscriptions(channels...) + err = ok.Websocket.AddSuccessfulSubscriptions(nil, channels...) } if err != nil { return err @@ -538,10 +538,10 @@ func (ok *Okx) handleSubscription(operation string, subscriptions subscription.L channels = append(channels, authChannels...) if operation == operationUnsubscribe { - return ok.Websocket.RemoveSubscriptions(channels...) + return ok.Websocket.RemoveSubscriptions(nil, channels...) } - return ok.Websocket.AddSuccessfulSubscriptions(channels...) + return ok.Websocket.AddSuccessfulSubscriptions(nil, channels...) } // WsHandleData will read websocket raw data and pass to appropriate handler diff --git a/exchanges/poloniex/poloniex_websocket.go b/exchanges/poloniex/poloniex_websocket.go index a5407915519..ebce7241318 100644 --- a/exchanges/poloniex/poloniex_websocket.go +++ b/exchanges/poloniex/poloniex_websocket.go @@ -608,9 +608,9 @@ func (p *Poloniex) manageSubs(subs subscription.List, op wsOp) error { } if err == nil { if op == wsSubscribeOp { - err = p.Websocket.AddSuccessfulSubscriptions(s) + err = p.Websocket.AddSuccessfulSubscriptions(nil, s) } else { - err = p.Websocket.RemoveSubscriptions(s) + err = p.Websocket.RemoveSubscriptions(nil, s) } } if err != nil { diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index c5c4e96b1e2..c50184f4561 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/url" + "os" "slices" "time" @@ -90,6 +91,7 @@ func NewWebsocket() *Websocket { subscriptions: subscription.NewStore(), features: &protocol.Features{}, Orderbook: buffer.Orderbook{}, + Connections: make(map[Connection]ConnectionAssociation), } } @@ -211,7 +213,6 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { // functions are defined per connection. Else we use the global connector // and supporting functions for backwards compatibility. if w.connector == nil { - fmt.Println("w.connector == nil") if c.Handler == nil { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset) } @@ -227,14 +228,14 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { if c.Connector == nil { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) } - w.PendingConnections = append(w.PendingConnections, c) + w.ConnectionManager = append(w.ConnectionManager, ConnectionDetails{Details: &c}) return nil } if c.Authenticated { - w.AuthConn = w.getConnectionFromSetup(c) + w.AuthConn = w.getConnectionFromSetup(&c) } else { - w.Conn = w.getConnectionFromSetup(c) + w.Conn = w.getConnectionFromSetup(&c) } return nil @@ -242,7 +243,7 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { // getConnectionFromSetup returns a websocket connection from a setup // configuration. This is used for setting up new connections on the fly. -func (w *Websocket) getConnectionFromSetup(c ConnectionSetup) *WebsocketConnection { +func (w *Websocket) getConnectionFromSetup(c *ConnectionSetup) *WebsocketConnection { connectionURL := w.GetWebsocketURL() if c.URL != "" { connectionURL = c.URL @@ -284,6 +285,14 @@ func (w *Websocket) Connect() error { } w.subscriptions.Clear() + for _, details := range w.Connections { + if details.Subscriptions == nil { + return fmt.Errorf("%w: subscriptions", common.ErrNilPointer) + } + details.Subscriptions.Clear() + } + w.Connections = make(map[Connection]ConnectionAssociation) + w.dataMonitor() w.trafficMonitor() w.setState(connectingState) @@ -308,7 +317,7 @@ func (w *Websocket) Connect() error { return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) } if len(subs) != 0 { - if err := w.SubscribeToChannels(subs); err != nil { + if err := w.SubscribeToChannels(nil, subs); err != nil { return err } } @@ -320,18 +329,18 @@ func (w *Websocket) Connect() error { hasStableConnection := false defer w.setStateFromHasStableConnection(&hasStableConnection) - if len(w.PendingConnections) == 0 { + if len(w.ConnectionManager) == 0 { return fmt.Errorf("cannot connect: %w", errNoPendingConnections) } // TODO: Implement concurrency below. This can be achieved once there is // more mutex protection around the subscriptions. - for i := range w.PendingConnections { - if w.PendingConnections[i].GenerateSubscriptions == nil { - return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.PendingConnections[i].URL, errWebsocketSubscriptionsGeneratorUnset) + for i := range w.ConnectionManager { + if w.ConnectionManager[i].Details.GenerateSubscriptions == nil { + return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.ConnectionManager[i].Details.URL, errWebsocketSubscriptionsGeneratorUnset) } - subs, err := w.PendingConnections[i].GenerateSubscriptions() // regenerate state on new connection + subs, err := w.ConnectionManager[i].Details.GenerateSubscriptions() // regenerate state on new connection if err != nil { if errors.Is(err, asset.ErrNotEnabled) { if w.verbose { @@ -350,19 +359,21 @@ func (w *Websocket) Connect() error { continue } - if w.PendingConnections[i].Connector == nil { - return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.PendingConnections[i].URL, errNoConnectFunc) + if w.ConnectionManager[i].Details.Connector == nil { + return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.ConnectionManager[i].Details.URL, errNoConnectFunc) } - if w.PendingConnections[i].Handler == nil { - return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.PendingConnections[i].URL, errWebsocketDataHandlerUnset) + if w.ConnectionManager[i].Details.Handler == nil { + return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.ConnectionManager[i].Details.URL, errWebsocketDataHandlerUnset) } - if w.PendingConnections[i].Subscriber == nil { - return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.PendingConnections[i].URL, errWebsocketSubscriberUnset) + if w.ConnectionManager[i].Details.Subscriber == nil { + return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.ConnectionManager[i].Details.URL, errWebsocketSubscriberUnset) } // TODO: Add window for max subscriptions per connection, to spawn new connections if needed. - conn := w.getConnectionFromSetup(w.PendingConnections[i]) - err = w.PendingConnections[i].Connector(context.TODO(), conn) + + conn := w.getConnectionFromSetup(w.ConnectionManager[i].Details) + + err = w.ConnectionManager[i].Details.Connector(context.TODO(), conn) if err != nil { return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) } @@ -370,12 +381,19 @@ func (w *Websocket) Connect() error { hasStableConnection = true w.Wg.Add(1) - go w.Reader(context.TODO(), conn, w.PendingConnections[i].Handler) + go w.Reader(context.TODO(), conn, w.ConnectionManager[i].Details.Handler) + + w.Connections[conn] = ConnectionAssociation{ + Subscriptions: subscription.NewStore(), + Details: w.ConnectionManager[i].Details, + } - err = w.PendingConnections[i].Subscriber(context.TODO(), conn, subs) + err = w.ConnectionManager[i].Details.Subscriber(context.TODO(), conn, subs) if err != nil { return fmt.Errorf("%v Error subscribing %w", w.exchangeName, err) } + + w.ConnectionManager[i].Connection = conn } if !w.IsConnectionMonitorRunning() { @@ -535,6 +553,14 @@ func (w *Websocket) Shutdown() error { defer w.Orderbook.FlushBuffer() + for conn, details := range w.Connections { + if err := conn.Shutdown(); err != nil { + return err + } + details.Subscriptions.Clear() + } + w.Connections = make(map[Connection]ConnectionAssociation) + if w.Conn != nil { if err := w.Conn.Shutdown(); err != nil { return err @@ -572,40 +598,91 @@ func (w *Websocket) FlushChannels() error { } if w.features.Subscribe { - newsubs, err := w.GenerateSubs() - if err != nil { - return err - } + if w.GenerateSubs != nil { + newsubs, err := w.GenerateSubs() + if err != nil { + return err + } - subs, unsubs := w.GetChannelDifference(newsubs) - if w.features.Unsubscribe { - if len(unsubs) != 0 { - err := w.UnsubscribeChannels(unsubs) + subs, unsubs := w.GetChannelDifference(nil, newsubs) + if len(unsubs) != 0 && w.features.Unsubscribe { + err := w.UnsubscribeChannels(nil, unsubs) if err != nil { return err } } + if len(subs) < 1 { + return nil + } + return w.SubscribeToChannels(nil, subs) } + for x := range w.ConnectionManager { + if w.ConnectionManager[x].Details.GenerateSubscriptions == nil { + continue + } + newsubs, err := w.ConnectionManager[x].Details.GenerateSubscriptions() + if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + continue + } + return err + } + subs, unsubs := w.GetChannelDifference(w.ConnectionManager[x].Connection, newsubs) + if len(unsubs) != 0 && w.features.Unsubscribe { - if len(subs) < 1 { - return nil + err := w.UnsubscribeChannels(w.ConnectionManager[x].Connection, unsubs) + if err != nil { + return err + } + } + if len(subs) != 0 { + err = w.SubscribeToChannels(w.ConnectionManager[x].Connection, subs) + if err != nil { + return err + } + } } - return w.SubscribeToChannels(subs) + return nil } else if w.features.FullPayloadSubscribe { // FullPayloadSubscribe means that the endpoint requires all // subscriptions to be sent via the websocket connection e.g. if you are // subscribed to ticker and orderbook but require trades as well, you // would need to send ticker, orderbook and trades channel subscription // messages. - newsubs, err := w.GenerateSubs() - if err != nil { - return err + + if w.GenerateSubs != nil { + newsubs, err := w.GenerateSubs() + if err != nil { + return err + } + + if len(newsubs) != 0 { + // Purge subscription list as there will be conflicts + w.subscriptions.Clear() + return w.SubscribeToChannels(nil, newsubs) + } + return nil } - if len(newsubs) != 0 { - // Purge subscription list as there will be conflicts - w.subscriptions.Clear() - return w.SubscribeToChannels(newsubs) + for x := range w.ConnectionManager { + if w.ConnectionManager[x].Details.GenerateSubscriptions == nil { + continue + } + newsubs, err := w.ConnectionManager[x].Details.GenerateSubscriptions() + if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + continue + } + return err + } + if len(newsubs) != 0 { + // Purge subscription list as there will be conflicts + w.Connections[w.ConnectionManager[x].Connection].Subscriptions.Clear() + err = w.SubscribeToChannels(w.ConnectionManager[x].Connection, newsubs) + if err != nil { + return err + } + } } return nil } @@ -861,7 +938,10 @@ func (w *Websocket) GetName() string { // GetChannelDifference finds the difference between the subscribed channels // and the new subscription list when pairs are disabled or enabled. -func (w *Websocket) GetChannelDifference(newSubs subscription.List) (sub, unsub subscription.List) { +func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) { + if conn != nil { + return w.Connections[conn].Subscriptions.Diff(newSubs) + } if w.subscriptions == nil { w.subscriptions = subscription.NewStore() } @@ -869,8 +949,28 @@ func (w *Websocket) GetChannelDifference(newSubs subscription.List) (sub, unsub } // UnsubscribeChannels unsubscribes from a list of websocket channel -func (w *Websocket) UnsubscribeChannels(channels subscription.List) error { - if w.subscriptions == nil || len(channels) == 0 { +func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.List) error { + if len(channels) == 0 { + return nil // No channels to unsubscribe from is not an error + } + + if conn != nil { + store, ok := w.Connections[conn] + if !ok { + return errors.New("connection not found") + } + if store.Subscriptions == nil { + return nil // No channels to unsubscribe from is not an error + } + for _, s := range channels { + if store.Subscriptions.Get(s) == nil { + return fmt.Errorf("%w: %s", subscription.ErrNotFound, s) + } + } + return store.Details.Unsubscriber(context.TODO(), conn, channels) + } + + if w.subscriptions == nil { return nil // No channels to unsubscribe from is not an error } for _, s := range channels { @@ -884,26 +984,35 @@ func (w *Websocket) UnsubscribeChannels(channels subscription.List) error { // ResubscribeToChannel resubscribes to channel // Sets state to Resubscribing, and exchanges which want to maintain a lock on it can respect this state and not RemoveSubscription // Errors if subscription is already subscribing -func (w *Websocket) ResubscribeToChannel(s *subscription.Subscription) error { +func (w *Websocket) ResubscribeToChannel(conn Connection, s *subscription.Subscription) error { l := subscription.List{s} if err := s.SetState(subscription.ResubscribingState); err != nil { return fmt.Errorf("%w: %s", err, s) } - if err := w.UnsubscribeChannels(l); err != nil { + if err := w.UnsubscribeChannels(conn, l); err != nil { return err } - return w.SubscribeToChannels(l) + return w.SubscribeToChannels(conn, l) } // SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method // Errors are returned for duplicates or exceeding max Subscriptions -func (w *Websocket) SubscribeToChannels(subs subscription.List) error { +func (w *Websocket) SubscribeToChannels(conn Connection, subs subscription.List) error { if slices.Contains(subs, nil) { return fmt.Errorf("%w: List parameter contains an nil element", common.ErrNilPointer) } - if err := w.checkSubscriptions(subs); err != nil { + if err := w.checkSubscriptions(conn, subs); err != nil { return err } + + if conn != nil { + state, ok := w.Connections[conn] + if !ok { + return errors.New("connection details not found") + } + return state.Details.Subscriber(context.TODO(), conn, subs) + } + if err := w.Subscriber(subs); err != nil { return fmt.Errorf("%w: %w", ErrSubscriptionFailure, err) } @@ -934,10 +1043,36 @@ func (w *Websocket) AddSubscriptions(subs ...*subscription.Subscription) error { } // AddSuccessfulSubscriptions marks subscriptions as subscribed and adds them to the subscription store -func (w *Websocket) AddSuccessfulSubscriptions(subs ...*subscription.Subscription) error { +func (w *Websocket) AddSuccessfulSubscriptions(conn Connection, subs ...*subscription.Subscription) error { if w == nil { return fmt.Errorf("%w: AddSuccessfulSubscriptions called on nil Websocket", common.ErrNilPointer) } + + if conn != nil { + state, ok := w.Connections[conn] + if !ok { + for k, v := range w.Connections { + fmt.Printf("key: %v, value: %v\n", k, v) + } + + fmt.Println("conn", conn) + + os.Exit(1) + + return errors.New("connection details not found") + } + var errs error + for _, s := range subs { + if err := s.SetState(subscription.SubscribedState); err != nil { + errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) + } + if err := state.Subscriptions.Add(s); err != nil { + errs = common.AppendError(errs, err) + } + } + return errs + } + if w.subscriptions == nil { w.subscriptions = subscription.NewStore() } @@ -954,10 +1089,28 @@ func (w *Websocket) AddSuccessfulSubscriptions(subs ...*subscription.Subscriptio } // RemoveSubscriptions removes subscriptions from the subscription list and sets the status to Unsubscribed -func (w *Websocket) RemoveSubscriptions(subs ...*subscription.Subscription) error { +func (w *Websocket) RemoveSubscriptions(conn Connection, subs ...*subscription.Subscription) error { if w == nil { return fmt.Errorf("%w: RemoveSubscriptions called on nil Websocket", common.ErrNilPointer) } + + if conn != nil { + state, ok := w.Connections[conn] + if !ok { + return errors.New("connection details not found") + } + var errs error + for _, s := range subs { + if err := s.SetState(subscription.UnsubscribedState); err != nil { + errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) + } + if err := state.Subscriptions.Remove(s); err != nil { + errs = common.AppendError(errs, err) + } + } + return errs + } + if w.subscriptions == nil { return fmt.Errorf("%w: RemoveSubscriptions called on uninitialised Websocket", common.ErrNilPointer) } @@ -1026,7 +1179,35 @@ func checkWebsocketURL(s string) error { // checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists // The subscription state is not considered when counting existing subscriptions -func (w *Websocket) checkSubscriptions(subs subscription.List) error { +func (w *Websocket) checkSubscriptions(conn Connection, subs subscription.List) error { + if conn != nil { + state, ok := w.Connections[conn] + if !ok { + return errors.New("connection ddetails not found") + } + + if state.Subscriptions == nil { + return fmt.Errorf("%w: Websocket.subscriptions", common.ErrNilPointer) + } + + existing := state.Subscriptions.Len() + if w.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > w.MaxSubscriptionsPerConnection { + return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs", + errSubscriptionsExceedsLimit, + existing, + len(subs), + w.MaxSubscriptionsPerConnection) + } + + for _, s := range subs { + if found := state.Subscriptions.Get(s); found != nil { + return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s) + } + } + + return nil + } + if w.subscriptions == nil { return fmt.Errorf("%w: Websocket.subscriptions", common.ErrNilPointer) } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index c79671a3762..2f22d69e9ed 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -341,54 +341,54 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Connect() assert.ErrorIs(t, err, errNoPendingConnections, "Connect should error correctly") - ws.PendingConnections = []ConnectionSetup{{URL: "ws://localhost:8080/ws"}} + ws.ConnectionManager = []ConnectionDetails{{Details: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}} err = ws.Connect() require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) - ws.PendingConnections[0].GenerateSubscriptions = func() (subscription.List, error) { + ws.ConnectionManager[0].Details.GenerateSubscriptions = func() (subscription.List, error) { return nil, errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.PendingConnections[0].GenerateSubscriptions = func() (subscription.List, error) { + ws.ConnectionManager[0].Details.GenerateSubscriptions = func() (subscription.List, error) { return subscription.List{{}}, nil } err = ws.Connect() require.ErrorIs(t, err, errNoConnectFunc) - ws.PendingConnections[0].Connector = func(context.Context, Connection) error { + ws.ConnectionManager[0].Details.Connector = func(context.Context, Connection) error { return errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errWebsocketDataHandlerUnset) - ws.PendingConnections[0].Handler = func(context.Context, []byte) error { + ws.ConnectionManager[0].Details.Handler = func(context.Context, []byte) error { return errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errWebsocketSubscriberUnset) - ws.PendingConnections[0].Subscriber = func(context.Context, Connection, subscription.List) error { + ws.ConnectionManager[0].Details.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.PendingConnections[0].Connector = func(ctx context.Context, conn Connection) error { + ws.ConnectionManager[0].Details.Connector = func(ctx context.Context, conn Connection) error { return nil } err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.PendingConnections[0].Handler = func(context.Context, []byte) error { + ws.ConnectionManager[0].Details.Handler = func(context.Context, []byte) error { return nil } require.NoError(t, ws.Shutdown()) err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.PendingConnections[0].Subscriber = func(context.Context, Connection, subscription.List) error { + ws.ConnectionManager[0].Details.Subscriber = func(context.Context, Connection, subscription.List) error { return nil } require.NoError(t, ws.Shutdown()) @@ -493,13 +493,13 @@ func TestWebsocket(t *testing.T) { func currySimpleSub(w *Websocket) func(subscription.List) error { return func(subs subscription.List) error { - return w.AddSuccessfulSubscriptions(subs...) + return w.AddSuccessfulSubscriptions(nil, subs...) } } func currySimpleUnsub(w *Websocket) func(subscription.List) error { return func(unsubs subscription.List) error { - return w.RemoveSubscriptions(unsubs...) + return w.RemoveSubscriptions(nil, unsubs...) } } @@ -514,11 +514,11 @@ func TestSubscribeUnsubscribe(t *testing.T) { subs, err := ws.GenerateSubs() require.NoError(t, err, "Generating test subscriptions should not error") - assert.NoError(t, new(Websocket).UnsubscribeChannels(subs), "Should not error when w.subscriptions is nil") - assert.NoError(t, ws.UnsubscribeChannels(nil), "Unsubscribing from nil should not error") - assert.ErrorIs(t, ws.UnsubscribeChannels(subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") + assert.NoError(t, new(Websocket).UnsubscribeChannels(nil, subs), "Should not error when w.subscriptions is nil") + assert.NoError(t, ws.UnsubscribeChannels(nil, nil), "Unsubscribing from nil should not error") + assert.ErrorIs(t, ws.UnsubscribeChannels(nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return") - assert.NoError(t, ws.SubscribeToChannels(subs), "Basic Subscribing should not error") + assert.NoError(t, ws.SubscribeToChannels(nil, subs), "Basic Subscribing should not error") assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions") bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"}) if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") { @@ -536,14 +536,14 @@ func TestSubscribeUnsubscribe(t *testing.T) { } assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil") assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil") - assert.ErrorIs(t, ws.SubscribeToChannels(subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") - assert.NoError(t, ws.SubscribeToChannels(nil), "Subscribe to an nil List should not error") - assert.NoError(t, ws.UnsubscribeChannels(subs), "Unsubscribing should not error") + assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") + assert.NoError(t, ws.SubscribeToChannels(nil, nil), "Subscribe to an nil List should not error") + assert.NoError(t, ws.UnsubscribeChannels(nil, subs), "Unsubscribing should not error") ws.Subscriber = func(subscription.List) error { return errDastardlyReason } - assert.ErrorIs(t, ws.SubscribeToChannels(subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") + assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") - err = ws.SubscribeToChannels(subscription.List{nil}) + err = ws.SubscribeToChannels(nil, subscription.List{nil}) assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription") } @@ -565,9 +565,9 @@ func TestResubscribe(t *testing.T) { channel := subscription.List{{Channel: "resubTest"}} - assert.ErrorIs(t, ws.ResubscribeToChannel(channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet") - assert.NoError(t, ws.SubscribeToChannels(channel), "Subscribe should not error") - assert.NoError(t, ws.ResubscribeToChannel(channel[0]), "Resubscribe should not error now the channel is subscribed") + assert.ErrorIs(t, ws.ResubscribeToChannel(nil, channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet") + assert.NoError(t, ws.SubscribeToChannels(nil, channel), "Subscribe should not error") + assert.NoError(t, ws.ResubscribeToChannel(nil, channel[0]), "Resubscribe should not error now the channel is subscribed") } // TestSubscriptions tests adding, getting and removing subscriptions @@ -581,7 +581,7 @@ func TestSubscriptions(t *testing.T) { assert.ErrorIs(t, w.AddSubscriptions(s), subscription.ErrDuplicate, "Adding same subscription should return error") assert.Equal(t, subscription.SubscribingState, s.State(), "Should set state to Subscribing") - err := w.RemoveSubscriptions(s) + err := w.RemoveSubscriptions(nil, s) require.NoError(t, err, "RemoveSubscriptions must not error") assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub") assert.Equal(t, subscription.UnsubscribedState, s.State(), "Should set state to Unsubscribed") @@ -595,21 +595,21 @@ func TestSubscriptions(t *testing.T) { func TestSuccessfulSubscriptions(t *testing.T) { t.Parallel() w := new(Websocket) // Do not use NewWebsocket; We want to exercise w.subs == nil - assert.ErrorIs(t, (*Websocket)(nil).AddSuccessfulSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket") + assert.ErrorIs(t, (*Websocket)(nil).AddSuccessfulSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket") c := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel} - require.NoError(t, w.AddSuccessfulSubscriptions(c), "Adding first subscription should not error") + require.NoError(t, w.AddSuccessfulSubscriptions(nil, c), "Adding first subscription should not error") assert.Same(t, c, w.GetSubscription(42), "Get Subscription should retrieve the same subscription") - assert.ErrorIs(t, w.AddSuccessfulSubscriptions(c), subscription.ErrInStateAlready, "Adding subscription in same state should return error") + assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrInStateAlready, "Adding subscription in same state should return error") require.NoError(t, c.SetState(subscription.SubscribingState), "SetState must not error") - assert.ErrorIs(t, w.AddSuccessfulSubscriptions(c), subscription.ErrDuplicate, "Adding same subscription should return error") + assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrDuplicate, "Adding same subscription should return error") - err := w.RemoveSubscriptions(c) + err := w.RemoveSubscriptions(nil, c) require.NoError(t, err, "RemoveSubscriptions must not error") assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub") - assert.ErrorIs(t, w.RemoveSubscriptions(c), subscription.ErrNotFound, "Should error correctly when not found") - assert.ErrorIs(t, (*Websocket)(nil).RemoveSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket") + assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), subscription.ErrNotFound, "Should error correctly when not found") + assert.ErrorIs(t, (*Websocket)(nil).RemoveSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket") w.subscriptions = nil - assert.ErrorIs(t, w.RemoveSubscriptions(c), common.ErrNilPointer, "Should error correctly when nil websocket") + assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), common.ErrNilPointer, "Should error correctly when nil websocket") } // TestConnectionMonitorNoConnection logic test @@ -988,12 +988,12 @@ func TestGetChannelDifference(t *testing.T) { t.Parallel() w := &Websocket{} - assert.NotPanics(t, func() { w.GetChannelDifference(subscription.List{}) }, "Should not panic when called without a store") - subs, unsubs := w.GetChannelDifference(subscription.List{{Channel: subscription.CandlesChannel}}) + assert.NotPanics(t, func() { w.GetChannelDifference(nil, subscription.List{}) }, "Should not panic when called without a store") + subs, unsubs := w.GetChannelDifference(nil, subscription.List{{Channel: subscription.CandlesChannel}}) require.Equal(t, 1, len(subs), "Should get the correct number of subs") require.Empty(t, unsubs, "Should get no unsubs") require.NoError(t, w.AddSubscriptions(subs...), "AddSubscriptions must not error") - subs, unsubs = w.GetChannelDifference(subscription.List{{Channel: subscription.TickerChannel}}) + subs, unsubs = w.GetChannelDifference(nil, subscription.List{{Channel: subscription.TickerChannel}}) require.Equal(t, 1, len(subs), "Should get the correct number of subs") assert.Equal(t, 1, len(unsubs), "Should get the correct number of unsubs") } @@ -1239,21 +1239,21 @@ func TestLatency(t *testing.T) { func TestCheckSubscriptions(t *testing.T) { t.Parallel() ws := Websocket{} - err := ws.checkSubscriptions(nil) + err := ws.checkSubscriptions(nil, nil) assert.ErrorIs(t, err, common.ErrNilPointer, "checkSubscriptions should error correctly on nil w.subscriptions") assert.ErrorContains(t, err, "Websocket.subscriptions", "checkSubscriptions should error giving context correctly on nil w.subscriptions") ws.subscriptions = subscription.NewStore() - err = ws.checkSubscriptions(nil) + err = ws.checkSubscriptions(nil, nil) assert.NoError(t, err, "checkSubscriptions should not error on a nil list") ws.MaxSubscriptionsPerConnection = 1 - err = ws.checkSubscriptions(subscription.List{{}}) + err = ws.checkSubscriptions(nil, subscription.List{{}}) assert.NoError(t, err, "checkSubscriptions should not error when subscriptions is empty") ws.subscriptions = subscription.NewStore() - err = ws.checkSubscriptions(subscription.List{{}, {}}) + err = ws.checkSubscriptions(nil, subscription.List{{}, {}}) assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 2 @@ -1261,9 +1261,9 @@ func TestCheckSubscriptions(t *testing.T) { ws.subscriptions = subscription.NewStore() err = ws.subscriptions.Add(&subscription.Subscription{Key: 42, Channel: "test"}) require.NoError(t, err, "Add subscription must not error") - err = ws.checkSubscriptions(subscription.List{{Key: 42, Channel: "test"}}) + err = ws.checkSubscriptions(nil, subscription.List{{Key: 42, Channel: "test"}}) assert.ErrorIs(t, err, subscription.ErrDuplicate, "checkSubscriptions should error correctly") - err = ws.checkSubscriptions(subscription.List{{}}) + err = ws.checkSubscriptions(nil, subscription.List{{}}) assert.NoError(t, err, "checkSubscriptions should not error") } diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 66473d5721c..dab7883054f 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -29,6 +29,16 @@ const ( connectedState ) +type ConnectionAssociation struct { + Subscriptions *subscription.Store + Details *ConnectionSetup +} + +type ConnectionDetails struct { + Details *ConnectionSetup + Connection Connection +} + // Websocket defines a return type for websocket connections via the interface // wrapper for routine processing type Websocket struct { @@ -50,7 +60,8 @@ type Websocket struct { m sync.Mutex connector func() error - PendingConnections []ConnectionSetup + ConnectionManager []ConnectionDetails + Connections map[Connection]ConnectionAssociation subscriptions *subscription.Store From e1f2f7a5fdfb5579bddc9e167291f1f1c100b0bc Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Mon, 15 Jul 2024 15:49:19 +1000 Subject: [PATCH 005/138] some changes --- exchanges/stream/websocket.go | 12 +----------- exchanges/stream/websocket_types.go | 4 +++- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index c50184f4561..6781692efb6 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "net/url" - "os" "slices" "time" @@ -333,8 +332,7 @@ func (w *Websocket) Connect() error { return fmt.Errorf("cannot connect: %w", errNoPendingConnections) } - // TODO: Implement concurrency below. This can be achieved once there is - // more mutex protection around the subscriptions. + // TODO: Implement concurrency below. for i := range w.ConnectionManager { if w.ConnectionManager[i].Details.GenerateSubscriptions == nil { return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.ConnectionManager[i].Details.URL, errWebsocketSubscriptionsGeneratorUnset) @@ -1051,14 +1049,6 @@ func (w *Websocket) AddSuccessfulSubscriptions(conn Connection, subs ...*subscri if conn != nil { state, ok := w.Connections[conn] if !ok { - for k, v := range w.Connections { - fmt.Printf("key: %v, value: %v\n", k, v) - } - - fmt.Println("conn", conn) - - os.Exit(1) - return errors.New("connection details not found") } var errs error diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index dab7883054f..e08c1dbffe0 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -29,14 +29,16 @@ const ( connectedState ) +// ConnectionAssociation contains the connection details and subscriptions type ConnectionAssociation struct { Subscriptions *subscription.Store Details *ConnectionSetup } +// ConnectionSetup contains the connection details and it's tracked connections type ConnectionDetails struct { Details *ConnectionSetup - Connection Connection + Connection Connection // TODO: Upgrade to slice of connections. } // Websocket defines a return type for websocket connections via the interface From 76524cc0ece073362c0f3e73466178175990d328 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Mon, 15 Jul 2024 16:15:15 +1000 Subject: [PATCH 006/138] linter: fixes strikes again. --- docs/ADD_NEW_EXCHANGE.md | 2 +- exchanges/binance/binance_wrapper.go | 2 +- exchanges/binanceus/binanceus_wrapper.go | 2 +- exchanges/bitfinex/bitfinex_wrapper.go | 4 ++-- exchanges/bithumb/bithumb_wrapper.go | 2 +- exchanges/bitmex/bitmex_wrapper.go | 2 +- exchanges/bitstamp/bitstamp_wrapper.go | 2 +- exchanges/btcmarkets/btcmarkets_wrapper.go | 2 +- exchanges/btse/btse_wrapper.go | 2 +- exchanges/bybit/bybit_wrapper.go | 4 ++-- exchanges/coinbasepro/coinbasepro_wrapper.go | 2 +- exchanges/coinut/coinut_wrapper.go | 2 +- exchanges/deribit/deribit_wrapper.go | 2 +- exchanges/gateio/gateio_websocket.go | 17 ++++++++++------- exchanges/gateio/gateio_wrapper.go | 10 +++++----- exchanges/gateio/gateio_ws_delivery_futures.go | 8 ++++---- exchanges/gateio/gateio_ws_futures.go | 12 ++++++------ exchanges/gateio/gateio_ws_option.go | 10 +++++----- exchanges/gemini/gemini_wrapper.go | 4 ++-- exchanges/hitbtc/hitbtc_wrapper.go | 2 +- exchanges/huobi/huobi_wrapper.go | 4 ++-- exchanges/kraken/kraken_wrapper.go | 4 ++-- exchanges/kucoin/kucoin_wrapper.go | 2 +- exchanges/okcoin/okcoin_wrapper.go | 4 ++-- exchanges/okx/okx_wrapper.go | 4 ++-- exchanges/poloniex/poloniex_wrapper.go | 2 +- exchanges/stream/websocket.go | 13 ++++++------- exchanges/stream/websocket_test.go | 14 +++++++------- exchanges/stream/websocket_types.go | 2 +- 29 files changed, 72 insertions(+), 70 deletions(-) diff --git a/docs/ADD_NEW_EXCHANGE.md b/docs/ADD_NEW_EXCHANGE.md index 834a8c3107b..16699b914e2 100644 --- a/docs/ADD_NEW_EXCHANGE.md +++ b/docs/ADD_NEW_EXCHANGE.md @@ -1137,7 +1137,7 @@ func (f *FTX) Setup(exch *config.Exchange) error { return err } // Sets up a new connection for the websocket, there are two separate connections denoted by the ConnectionSetup struct auth bool. - return f.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return f.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, // RateLimit int64 rudimentary rate limit that sleeps connection in milliseconds before sending designated payload diff --git a/exchanges/binance/binance_wrapper.go b/exchanges/binance/binance_wrapper.go index 951de41998f..41dde2dc85a 100644 --- a/exchanges/binance/binance_wrapper.go +++ b/exchanges/binance/binance_wrapper.go @@ -255,7 +255,7 @@ func (b *Binance) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: wsRateLimitMilliseconds, diff --git a/exchanges/binanceus/binanceus_wrapper.go b/exchanges/binanceus/binanceus_wrapper.go index 8314c22dd0d..e8ddfeb2a97 100644 --- a/exchanges/binanceus/binanceus_wrapper.go +++ b/exchanges/binanceus/binanceus_wrapper.go @@ -185,7 +185,7 @@ func (bi *Binanceus) Setup(exch *config.Exchange) error { return err } - return bi.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return bi.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: wsRateLimitMilliseconds, diff --git a/exchanges/bitfinex/bitfinex_wrapper.go b/exchanges/bitfinex/bitfinex_wrapper.go index 8258de9ddcc..9dc138ab9f7 100644 --- a/exchanges/bitfinex/bitfinex_wrapper.go +++ b/exchanges/bitfinex/bitfinex_wrapper.go @@ -218,7 +218,7 @@ func (b *Bitfinex) Setup(exch *config.Exchange) error { return err } - err = b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: publicBitfinexWebsocketEndpoint, @@ -227,7 +227,7 @@ func (b *Bitfinex) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: authenticatedBitfinexWebsocketEndpoint, diff --git a/exchanges/bithumb/bithumb_wrapper.go b/exchanges/bithumb/bithumb_wrapper.go index 8b30aee0714..55be0766e09 100644 --- a/exchanges/bithumb/bithumb_wrapper.go +++ b/exchanges/bithumb/bithumb_wrapper.go @@ -169,7 +169,7 @@ func (b *Bithumb) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: wsRateLimitMillisecond, diff --git a/exchanges/bitmex/bitmex_wrapper.go b/exchanges/bitmex/bitmex_wrapper.go index 0d20a361665..78a13945fe3 100644 --- a/exchanges/bitmex/bitmex_wrapper.go +++ b/exchanges/bitmex/bitmex_wrapper.go @@ -208,7 +208,7 @@ func (b *Bitmex) Setup(exch *config.Exchange) error { if err != nil { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: bitmexWSURL, diff --git a/exchanges/bitstamp/bitstamp_wrapper.go b/exchanges/bitstamp/bitstamp_wrapper.go index 5ce0156d61d..f8ff0632fa4 100644 --- a/exchanges/bitstamp/bitstamp_wrapper.go +++ b/exchanges/bitstamp/bitstamp_wrapper.go @@ -163,7 +163,7 @@ func (b *Bitstamp) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: b.Websocket.GetWebsocketURL(), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/btcmarkets/btcmarkets_wrapper.go b/exchanges/btcmarkets/btcmarkets_wrapper.go index ece27598300..db2596b3d7e 100644 --- a/exchanges/btcmarkets/btcmarkets_wrapper.go +++ b/exchanges/btcmarkets/btcmarkets_wrapper.go @@ -172,7 +172,7 @@ func (b *BTCMarkets) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, }) diff --git a/exchanges/btse/btse_wrapper.go b/exchanges/btse/btse_wrapper.go index 27df6de1147..3b897a4a7f9 100644 --- a/exchanges/btse/btse_wrapper.go +++ b/exchanges/btse/btse_wrapper.go @@ -197,7 +197,7 @@ func (b *BTSE) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, }) diff --git a/exchanges/bybit/bybit_wrapper.go b/exchanges/bybit/bybit_wrapper.go index 66922e259b4..8d328d756a7 100644 --- a/exchanges/bybit/bybit_wrapper.go +++ b/exchanges/bybit/bybit_wrapper.go @@ -226,7 +226,7 @@ func (by *Bybit) Setup(exch *config.Exchange) error { if err != nil { return err } - err = by.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: by.Websocket.GetWebsocketURL(), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: bybitWebsocketTimer, @@ -235,7 +235,7 @@ func (by *Bybit) Setup(exch *config.Exchange) error { return err } - return by.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: websocketPrivate, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/coinbasepro/coinbasepro_wrapper.go b/exchanges/coinbasepro/coinbasepro_wrapper.go index 81a64ccfd87..d5d9173db78 100644 --- a/exchanges/coinbasepro/coinbasepro_wrapper.go +++ b/exchanges/coinbasepro/coinbasepro_wrapper.go @@ -174,7 +174,7 @@ func (c *CoinbasePro) Setup(exch *config.Exchange) error { return err } - return c.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return c.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, }) diff --git a/exchanges/coinut/coinut_wrapper.go b/exchanges/coinut/coinut_wrapper.go index 2a8fd264a89..4936d5d5ae1 100644 --- a/exchanges/coinut/coinut_wrapper.go +++ b/exchanges/coinut/coinut_wrapper.go @@ -148,7 +148,7 @@ func (c *COINUT) Setup(exch *config.Exchange) error { return err } - return c.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return c.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: wsRateLimitInMilliseconds, diff --git a/exchanges/deribit/deribit_wrapper.go b/exchanges/deribit/deribit_wrapper.go index abe7c67f0da..38023875caa 100644 --- a/exchanges/deribit/deribit_wrapper.go +++ b/exchanges/deribit/deribit_wrapper.go @@ -211,7 +211,7 @@ func (d *Deribit) Setup(exch *config.Exchange) error { // setup option decimal regex at startup to make constant checks more efficient optionRegex = regexp.MustCompile(optionDecimalRegex) - return d.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return d.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: d.Websocket.GetWebsocketURL(), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 1ed473f5dd2..f03b98f7682 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -47,6 +47,9 @@ const ( spotFundingBalanceChannel = "spot.funding_balances" crossMarginBalanceChannel = "spot.cross_balances" crossMarginLoanChannel = "spot.cross_loan" + + subscribeEvent = "subscribe" + unsubscribeEvent = "unsubscribe" ) var defaultSubscriptions = []string{ @@ -57,7 +60,7 @@ var defaultSubscriptions = []string{ var fetchedCurrencyPairSnapshotOrderbook = make(map[string]bool) -// WsConnect initiates a websocket connection +// WsConnectSpot initiates a websocket connection func (g *Gateio) WsConnectSpot(ctx context.Context, conn stream.Connection) error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { return stream.ErrWebsocketNotEnabled @@ -93,14 +96,14 @@ func (g *Gateio) generateWsSignature(secret, event, channel string, dtime time.T } // WsHandleSpotData handles spot data -func (g *Gateio) WsHandleSpotData(ctx context.Context, respRaw []byte) error { +func (g *Gateio) WsHandleSpotData(_ context.Context, respRaw []byte) error { var push WsResponse err := json.Unmarshal(respRaw, &push) if err != nil { return err } - if push.Event == "subscribe" || push.Event == "unsubscribe" { + if push.Event == subscribeEvent || push.Event == unsubscribeEvent { if !g.Websocket.Match.IncomingWithData(push.ID, respRaw) { return fmt.Errorf("couldn't match subscription message with ID: %d", push.ID) } @@ -608,7 +611,7 @@ func (g *Gateio) processCrossMarginLoans(data []byte) error { return nil } -// GenerateDefaultSubscriptions returns default subscriptions +// GenerateDefaultSubscriptionsSpot returns default subscriptions func (g *Gateio) GenerateDefaultSubscriptionsSpot() (subscription.List, error) { channelsToSubscribe := defaultSubscriptions if g.Websocket.CanUseAuthenticatedEndpoints() { @@ -696,7 +699,7 @@ func (g *Gateio) handleSubscription(ctx context.Context, conn stream.Connection, errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", payloads[k].Event, payloads[k].Channel, resp.Error.Code, resp.Error.Message)) continue } - if payloads[k].Event == "subscribe" { + if payloads[k].Event == subscribeEvent { err = g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[k]) } else { err = g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[k]) @@ -830,12 +833,12 @@ func (g *Gateio) generatePayload(ctx context.Context, conn stream.Connection, ev // SpotSubscribe sends a websocket message to stop receiving data from the channel func (g *Gateio) SpotSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleSubscription(ctx, conn, "subscribe", channelsToUnsubscribe) + return g.handleSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe) } // SpotUnsubscribe sends a websocket message to stop receiving data from the channel func (g *Gateio) SpotUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleSubscription(ctx, conn, "unsubscribe", channelsToUnsubscribe) + return g.handleSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe) } func (g *Gateio) listOfAssetsCurrencyPairEnabledFor(cp currency.Pair) map[asset.Item]bool { diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index cdeeb96ff90..036bf50a9e7 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -209,7 +209,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { return err } // Spot connection - err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: gateioWebsocketEndpoint, RateLimit: gateioWebsocketRateLimit, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, @@ -224,7 +224,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { return err } // Futures connection - USDT margined - err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: futuresWebsocketUsdtURL, RateLimit: gateioWebsocketRateLimit, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, @@ -242,7 +242,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { } // Futures connection - BTC margined - err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: futuresWebsocketBtcURL, RateLimit: gateioWebsocketRateLimit, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, @@ -261,7 +261,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // TODO: Add BTC margined delivery futures. // Futures connection - Delivery - USDT margined - err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: deliveryRealUSDTTradingURL, RateLimit: gateioWebsocketRateLimit, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, @@ -279,7 +279,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { } // Futures connection - Options - return g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: optionsWebsocketURL, RateLimit: gateioWebsocketRateLimit, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index 6d0dfb71248..da0dc2130b8 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -116,12 +116,12 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.Lis // DeliveryFuturesSubscribe sends a websocket message to stop receiving data from the channel func (g *Gateio) DeliveryFuturesSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleDeliveryFuturesSubscription(ctx, conn, "subscribe", channelsToUnsubscribe) + return g.handleDeliveryFuturesSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe) } // DeliveryFuturesUnsubscribe sends a websocket message to stop receiving data from the channel func (g *Gateio) DeliveryFuturesUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleDeliveryFuturesSubscription(ctx, conn, "unsubscribe", channelsToUnsubscribe) + return g.handleDeliveryFuturesSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe) } // handleDeliveryFuturesSubscription sends a websocket message to receive data from the channel @@ -146,7 +146,7 @@ func (g *Gateio) handleDeliveryFuturesSubscription(ctx context.Context, conn str errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", val.Event, val.Channel, resp.Error.Code, resp.Error.Message)) continue } - if val.Event == "subscribe" { + if val.Event == subscribeEvent { err = g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[i]) } else { err = g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[i]) @@ -171,7 +171,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn stream g.Websocket.SetCanUseAuthenticatedEndpoints(false) } } - var outbound []WsInput + outbound := make([]WsInput, 0, len(channelsToSubscribe)) for i := range channelsToSubscribe { if len(channelsToSubscribe[i].Pairs) != 1 { return nil, subscription.ErrNotSinglePair diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 5c44093f283..561a1e73c6d 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -159,23 +159,23 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions(settlement currency.Code) ( // FuturesSubscribe sends a websocket message to stop receiving data from the channel func (g *Gateio) FuturesSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleFuturesSubscription(ctx, conn, "subscribe", channelsToUnsubscribe) + return g.handleFuturesSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe) } // FuturesUnsubscribe sends a websocket message to stop receiving data from the channel func (g *Gateio) FuturesUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleFuturesSubscription(ctx, conn, "unsubscribe", channelsToUnsubscribe) + return g.handleFuturesSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe) } // WsHandleFuturesData handles futures websocket data -func (g *Gateio) WsHandleFuturesData(ctx context.Context, respRaw []byte, a asset.Item) error { +func (g *Gateio) WsHandleFuturesData(_ context.Context, respRaw []byte, a asset.Item) error { var push WsResponse err := json.Unmarshal(respRaw, &push) if err != nil { return err } - if push.Event == "subscribe" || push.Event == "unsubscribe" { + if push.Event == subscribeEvent || push.Event == unsubscribeEvent { if !g.Websocket.Match.IncomingWithData(push.ID, respRaw) { return fmt.Errorf("couldn't match subscription message with ID: %d", push.ID) } @@ -249,7 +249,7 @@ func (g *Gateio) handleFuturesSubscription(ctx context.Context, conn stream.Conn errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", val.Event, val.Channel, resp.Error.Code, resp.Error.Message)) continue } - if val.Event == "subscribe" { + if val.Event == subscribeEvent { err = g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[i]) } else { err = g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[i]) @@ -278,7 +278,7 @@ func (g *Gateio) generateFuturesPayload(ctx context.Context, conn stream.Connect } } - var outbound []WsInput + outbound := make([]WsInput, 0, len(channelsToSubscribe)) for i := range channelsToSubscribe { if len(channelsToSubscribe[i].Pairs) != 1 { return nil, subscription.ErrNotSinglePair diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index 51a396b8427..9524cd6d34d 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -286,12 +286,12 @@ func (g *Gateio) generateOptionsPayload(ctx context.Context, conn stream.Connect // OptionsSubscribe sends a websocket message to stop receiving data for asset type options func (g *Gateio) OptionsSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleOptionsSubscription(ctx, conn, "subscribe", channelsToUnsubscribe) + return g.handleOptionsSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe) } // OptionsUnsubscribe sends a websocket message to stop receiving data for asset type options func (g *Gateio) OptionsUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleOptionsSubscription(ctx, conn, "unsubscribe", channelsToUnsubscribe) + return g.handleOptionsSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe) } // handleOptionsSubscription sends a websocket message to receive data from the channel @@ -315,7 +315,7 @@ func (g *Gateio) handleOptionsSubscription(ctx context.Context, conn stream.Conn errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s asset type: options error code: %d message: %s", payloads[k].Event, payloads[k].Channel, resp.Error.Code, resp.Error.Message)) continue } - if payloads[k].Event == "subscribe" { + if payloads[k].Event == subscribeEvent { err = g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[k]) } else { err = g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[k]) @@ -329,14 +329,14 @@ func (g *Gateio) handleOptionsSubscription(ctx context.Context, conn stream.Conn } // WsHandleOptionsData handles options websocket data -func (g *Gateio) WsHandleOptionsData(ctx context.Context, respRaw []byte) error { +func (g *Gateio) WsHandleOptionsData(_ context.Context, respRaw []byte) error { var push WsResponse err := json.Unmarshal(respRaw, &push) if err != nil { return err } - if push.Event == "subscribe" || push.Event == "unsubscribe" { + if push.Event == subscribeEvent || push.Event == unsubscribeEvent { if !g.Websocket.Match.IncomingWithData(push.ID, respRaw) { return fmt.Errorf("couldn't match subscription message with ID: %d", push.ID) } diff --git a/exchanges/gemini/gemini_wrapper.go b/exchanges/gemini/gemini_wrapper.go index e97c9d93e58..1159d37638e 100644 --- a/exchanges/gemini/gemini_wrapper.go +++ b/exchanges/gemini/gemini_wrapper.go @@ -152,7 +152,7 @@ func (g *Gemini) Setup(exch *config.Exchange) error { return err } - err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: geminiWebsocketEndpoint + "/v2/" + geminiWsMarketData, @@ -161,7 +161,7 @@ func (g *Gemini) Setup(exch *config.Exchange) error { return err } - return g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: geminiWebsocketEndpoint + "/v1/" + geminiWsOrderEvents, diff --git a/exchanges/hitbtc/hitbtc_wrapper.go b/exchanges/hitbtc/hitbtc_wrapper.go index 3e1bfc6902d..338434428fc 100644 --- a/exchanges/hitbtc/hitbtc_wrapper.go +++ b/exchanges/hitbtc/hitbtc_wrapper.go @@ -168,7 +168,7 @@ func (h *HitBTC) Setup(exch *config.Exchange) error { return err } - return h.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return h.Websocket.SetupNewConnection(&stream.ConnectionSetup{ RateLimit: rateLimit, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/huobi/huobi_wrapper.go b/exchanges/huobi/huobi_wrapper.go index 3c1002300c1..26ef82648a6 100644 --- a/exchanges/huobi/huobi_wrapper.go +++ b/exchanges/huobi/huobi_wrapper.go @@ -219,7 +219,7 @@ func (h *HUOBI) Setup(exch *config.Exchange) error { return err } - err = h.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = h.Websocket.SetupNewConnection(&stream.ConnectionSetup{ RateLimit: rateLimit, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, @@ -228,7 +228,7 @@ func (h *HUOBI) Setup(exch *config.Exchange) error { return err } - return h.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return h.Websocket.SetupNewConnection(&stream.ConnectionSetup{ RateLimit: rateLimit, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/kraken/kraken_wrapper.go b/exchanges/kraken/kraken_wrapper.go index 397d1a52a3a..7a9b5663bd5 100644 --- a/exchanges/kraken/kraken_wrapper.go +++ b/exchanges/kraken/kraken_wrapper.go @@ -231,7 +231,7 @@ func (k *Kraken) Setup(exch *config.Exchange) error { return err } - err = k.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = k.Websocket.SetupNewConnection(&stream.ConnectionSetup{ RateLimit: krakenWsRateLimit, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, @@ -241,7 +241,7 @@ func (k *Kraken) Setup(exch *config.Exchange) error { return err } - return k.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return k.Websocket.SetupNewConnection(&stream.ConnectionSetup{ RateLimit: krakenWsRateLimit, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/kucoin/kucoin_wrapper.go b/exchanges/kucoin/kucoin_wrapper.go index d7cb02996e8..da8a47d3a61 100644 --- a/exchanges/kucoin/kucoin_wrapper.go +++ b/exchanges/kucoin/kucoin_wrapper.go @@ -221,7 +221,7 @@ func (ku *Kucoin) Setup(exch *config.Exchange) error { if err != nil { return err } - return ku.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return ku.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: 500, diff --git a/exchanges/okcoin/okcoin_wrapper.go b/exchanges/okcoin/okcoin_wrapper.go index 0eeff738528..a231a5fd631 100644 --- a/exchanges/okcoin/okcoin_wrapper.go +++ b/exchanges/okcoin/okcoin_wrapper.go @@ -168,7 +168,7 @@ func (o *Okcoin) Setup(exch *config.Exchange) error { if err != nil { return err } - err = o.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = o.Websocket.SetupNewConnection(&stream.ConnectionSetup{ RateLimit: okcoinWsRateLimit, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, @@ -176,7 +176,7 @@ func (o *Okcoin) Setup(exch *config.Exchange) error { if err != nil { return err } - return o.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return o.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: okcoinPrivateWebsocketURL, diff --git a/exchanges/okx/okx_wrapper.go b/exchanges/okx/okx_wrapper.go index d29b7285b9b..a8b8995744f 100644 --- a/exchanges/okx/okx_wrapper.go +++ b/exchanges/okx/okx_wrapper.go @@ -218,7 +218,7 @@ func (ok *Okx) Setup(exch *config.Exchange) error { go ok.WsResponseMultiplexer.Run() - if err := ok.Websocket.SetupNewConnection(stream.ConnectionSetup{ + if err := ok.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: okxAPIWebsocketPublicURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: okxWebsocketResponseMaxLimit, @@ -227,7 +227,7 @@ func (ok *Okx) Setup(exch *config.Exchange) error { return err } - return ok.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return ok.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: okxAPIWebsocketPrivateURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: okxWebsocketResponseMaxLimit, diff --git a/exchanges/poloniex/poloniex_wrapper.go b/exchanges/poloniex/poloniex_wrapper.go index b115e0b4220..2230bfc1dfc 100644 --- a/exchanges/poloniex/poloniex_wrapper.go +++ b/exchanges/poloniex/poloniex_wrapper.go @@ -180,7 +180,7 @@ func (p *Poloniex) Setup(exch *config.Exchange) error { return err } - return p.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return p.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, }) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 6781692efb6..caadfcc147d 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -188,7 +188,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { } // SetupNewConnection sets up an auth or unauth streaming connection -func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { +func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { if w == nil { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketIsNil) } @@ -227,14 +227,14 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { if c.Connector == nil { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) } - w.ConnectionManager = append(w.ConnectionManager, ConnectionDetails{Details: &c}) + w.ConnectionManager = append(w.ConnectionManager, ConnectionDetails{Details: c}) return nil } if c.Authenticated { - w.AuthConn = w.getConnectionFromSetup(&c) + w.AuthConn = w.getConnectionFromSetup(c) } else { - w.Conn = w.getConnectionFromSetup(&c) + w.Conn = w.getConnectionFromSetup(c) } return nil @@ -305,7 +305,7 @@ func (w *Websocket) Connect() error { w.setState(connectedState) if !w.IsConnectionMonitorRunning() { - err := w.connectionMonitor() + err = w.connectionMonitor() if err != nil { log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err) } @@ -627,8 +627,7 @@ func (w *Websocket) FlushChannels() error { } subs, unsubs := w.GetChannelDifference(w.ConnectionManager[x].Connection, newsubs) if len(unsubs) != 0 && w.features.Unsubscribe { - - err := w.UnsubscribeChannels(w.ConnectionManager[x].Connection, unsubs) + err = w.UnsubscribeChannels(w.ConnectionManager[x].Connection, unsubs) if err != nil { return err } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 2f22d69e9ed..3f52d0312c6 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -375,7 +375,7 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.ConnectionManager[0].Details.Connector = func(ctx context.Context, conn Connection) error { + ws.ConnectionManager[0].Details.Connector = func(context.Context, Connection) error { return nil } err = ws.Connect() @@ -1144,19 +1144,19 @@ func TestEnable(t *testing.T) { func TestSetupNewConnection(t *testing.T) { t.Parallel() var nonsenseWebsock *Websocket - err := nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) + err := nonsenseWebsock.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errWebsocketIsNil, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{} - err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) + err = nonsenseWebsock.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{exchangeName: "test"} - err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) + err = nonsenseWebsock.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errTrafficAlertNil, "SetupNewConnection should error correctly") nonsenseWebsock.TrafficAlert = make(chan struct{}, 1) - err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) + err = nonsenseWebsock.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errReadMessageErrorsNil, "SetupNewConnection should error correctly") web := NewWebsocket() @@ -1164,10 +1164,10 @@ func TestSetupNewConnection(t *testing.T) { err = web.Setup(defaultSetup) assert.NoError(t, err, "Setup should not error") - err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) + err = web.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) assert.NoError(t, err, "SetupNewConnection should not error") - err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring", Authenticated: true}) + err = web.SetupNewConnection(&ConnectionSetup{URL: "urlstring", Authenticated: true}) assert.NoError(t, err, "SetupNewConnection should not error") } diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index e08c1dbffe0..4cf262de492 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -35,7 +35,7 @@ type ConnectionAssociation struct { Details *ConnectionSetup } -// ConnectionSetup contains the connection details and it's tracked connections +// ConnectionDetails contains the connection details and it's tracked connections type ConnectionDetails struct { Details *ConnectionSetup Connection Connection // TODO: Upgrade to slice of connections. From 640e82ec1ecad35f9df85ca914277d9b855ccde0 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 16 Jul 2024 09:06:43 +1000 Subject: [PATCH 007/138] Change name ConnectionAssociation -> ConnectionCandidate for better clarity on purpose. Change connections map to point to candidate to track subscriptions for future dynamic connections holder and drop struct ConnectionDetails. --- exchanges/stream/websocket.go | 17 +++++++++-------- exchanges/stream/websocket_test.go | 2 +- exchanges/stream/websocket_types.go | 29 ++++++++++++++++++----------- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index caadfcc147d..cef1fc6b1a5 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -90,7 +90,7 @@ func NewWebsocket() *Websocket { subscriptions: subscription.NewStore(), features: &protocol.Features{}, Orderbook: buffer.Orderbook{}, - Connections: make(map[Connection]ConnectionAssociation), + Connections: make(map[Connection]*ConnectionCandidate), } } @@ -227,7 +227,10 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { if c.Connector == nil { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) } - w.ConnectionManager = append(w.ConnectionManager, ConnectionDetails{Details: c}) + w.ConnectionManager = append(w.ConnectionManager, ConnectionCandidate{ + Details: c, + Subscriptions: subscription.NewStore(), + }) return nil } @@ -290,7 +293,6 @@ func (w *Websocket) Connect() error { } details.Subscriptions.Clear() } - w.Connections = make(map[Connection]ConnectionAssociation) w.dataMonitor() w.trafficMonitor() @@ -381,10 +383,7 @@ func (w *Websocket) Connect() error { w.Wg.Add(1) go w.Reader(context.TODO(), conn, w.ConnectionManager[i].Details.Handler) - w.Connections[conn] = ConnectionAssociation{ - Subscriptions: subscription.NewStore(), - Details: w.ConnectionManager[i].Details, - } + w.Connections[conn] = &w.ConnectionManager[i] err = w.ConnectionManager[i].Details.Subscriber(context.TODO(), conn, subs) if err != nil { @@ -557,7 +556,9 @@ func (w *Websocket) Shutdown() error { } details.Subscriptions.Clear() } - w.Connections = make(map[Connection]ConnectionAssociation) + + // Clean map of old connections + clear(w.Connections) if w.Conn != nil { if err := w.Conn.Shutdown(); err != nil { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 3f52d0312c6..026f3eb2396 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -341,7 +341,7 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Connect() assert.ErrorIs(t, err, errNoPendingConnections, "Connect should error correctly") - ws.ConnectionManager = []ConnectionDetails{{Details: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}} + ws.ConnectionManager = []ConnectionCandidate{{Details: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}} err = ws.Connect() require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 4cf262de492..b88c55c5488 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -29,16 +29,19 @@ const ( connectedState ) -// ConnectionAssociation contains the connection details and subscriptions -type ConnectionAssociation struct { - Subscriptions *subscription.Store +// // ConnectionAssociation contains the connection details and subscriptions +// type ConnectionAssociation struct { +// Subscriptions *subscription.Store +// Details *ConnectionSetup +// } + +// ConnectionCandidate contains the connection setup details to be used when +// attempting a new connection. It also contains the subscriptions that are +// associated with the specifc connection. +type ConnectionCandidate struct { Details *ConnectionSetup -} - -// ConnectionDetails contains the connection details and it's tracked connections -type ConnectionDetails struct { - Details *ConnectionSetup - Connection Connection // TODO: Upgrade to slice of connections. + Subscriptions *subscription.Store + Connection Connection // TODO: Upgrade to slice of connections. } // Websocket defines a return type for websocket connections via the interface @@ -62,8 +65,12 @@ type Websocket struct { m sync.Mutex connector func() error - ConnectionManager []ConnectionDetails - Connections map[Connection]ConnectionAssociation + // ConnectionManager contains the connection candidates and the current + // connections + ConnectionManager []ConnectionCandidate + // Connections contains the current connections with their associated + // connection candidates + Connections map[Connection]*ConnectionCandidate subscriptions *subscription.Store From 3a0440d06e1b3149d0f8d841f69885dd00d122e9 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 16 Jul 2024 12:42:54 +1000 Subject: [PATCH 008/138] Add subscription tests (state functional) --- exchanges/gateio/gateio_wrapper.go | 7 - exchanges/stream/stream_types.go | 42 ++++- exchanges/stream/websocket.go | 251 ++++++++++++++-------------- exchanges/stream/websocket_test.go | 197 ++++++++++++++++++++-- exchanges/stream/websocket_types.go | 15 -- 5 files changed, 345 insertions(+), 167 deletions(-) diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 036bf50a9e7..1c996afe219 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -192,15 +192,8 @@ func (g *Gateio) Setup(exch *config.Exchange) error { return err } - wsRunningURL, err := g.API.Endpoints.GetURL(exchange.WebsocketSpot) - if err != nil { - return err - } - err = g.Websocket.Setup(&stream.WebsocketSetup{ ExchangeConfig: exch, - DefaultURL: gateioWebsocketEndpoint, - RunningURL: wsRunningURL, Features: &g.Features.Supports.WebsocketCapabilities, FillsFeed: g.Features.Enabled.FillsFeed, TradeFeed: g.Features.Enabled.TradeFeed, diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 060dbb46f96..0cfa66a06ca 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -39,14 +39,44 @@ type ConnectionSetup struct { ResponseCheckTimeout time.Duration ResponseMaxLimit time.Duration RateLimit int64 - URL string Authenticated bool ConnectionLevelReporter Reporter - Handler func(ctx context.Context, incoming []byte) error - Subscriber func(ctx context.Context, conn Connection, sub subscription.List) error - Unsubscriber func(ctx context.Context, conn Connection, unsub subscription.List) error - GenerateSubscriptions func() (subscription.List, error) - Connector func(ctx context.Context, conn Connection) error + + // URL defines the websocket server URL to connect to + URL string + // Connector is the function that will be called to connect to the + // exchange's websocket server. This will be called once when the stream + // service is started. Any bespoke connection logic should be handled here. + Connector func(ctx context.Context, conn Connection) error + // GenerateSubscriptions is a function that will be called to generate a + // list of subscriptions to be made to the exchange's websocket server. + GenerateSubscriptions func() (subscription.List, error) + // Subscriber is a function that will be called to send subscription + // messages based on the exchange's websocket server requirements to + // subscribe to specific channels. + Subscriber func(ctx context.Context, conn Connection, sub subscription.List) error + // Unsubscriber is a function that will be called to send unsubscription + // messages based on the exchange's websocket server requirements to + // unsubscribe from specific channels. NOTE: IF THE FEATURE IS ENABLED. + Unsubscriber func(ctx context.Context, conn Connection, unsub subscription.List) error + // Handler defines the function that will be called when a message is + // received from the exchange's websocket server. This function should + // handle the incoming message and pass it to the appropriate data handler. + Handler func(ctx context.Context, incoming []byte) error +} + +// ConnectionCandidate contains the connection setup details to be used when +// attempting a new connection. It also contains the subscriptions that are +// associated with the specifc connection. +type ConnectionCandidate struct { + // Details contains the connection setup details + Details *ConnectionSetup + // Subscriptions contains the subscriptions that are associated with the + // specific connection(s) + Subscriptions *subscription.Store + // Connection contains the active connection based off the connection + // details above. + Connection Connection // TODO: Upgrade to slice of connections. } // PingHandler container for ping handler settings diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index cef1fc6b1a5..e569d53fd77 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -131,33 +131,51 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { } w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket) - w.connector = s.Connector - w.Subscriber = s.Subscriber - w.Unsubscriber = s.Unsubscriber - w.GenerateSubs = s.GenerateSubscriptions - - w.connectionMonitorDelay = s.ExchangeConfig.ConnectionMonitorDelay - if w.connectionMonitorDelay <= 0 { - w.connectionMonitorDelay = config.DefaultConnectionMonitorDelay - } + // If any fields here are set, assume that the previous global connector + // pattern is being used. + // TODO: Shift everything to connection setup when all exchanges are updated. + if s.Connector != nil || s.Subscriber != nil || s.Unsubscriber != nil || s.GenerateSubscriptions != nil || s.DefaultURL != "" || s.RunningURL != "" { + if s.Connector == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) + } + if s.Subscriber == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriberUnset) + } + if s.Unsubscriber == nil && w.features.Unsubscribe { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketUnsubscriberUnset) + } + if s.GenerateSubscriptions == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriptionsGeneratorUnset) + } + if s.DefaultURL == "" { + return fmt.Errorf("%s websocket %w", w.exchangeName, errDefaultURLIsEmpty) + } + w.defaultURL = s.DefaultURL + if s.RunningURL == "" { + return fmt.Errorf("%s websocket %w", w.exchangeName, errRunningURLIsEmpty) + } - if s.DefaultURL == "" { - return fmt.Errorf("%s websocket %w", w.exchangeName, errDefaultURLIsEmpty) - } - w.defaultURL = s.DefaultURL - if s.RunningURL == "" { - return fmt.Errorf("%s websocket %w", w.exchangeName, errRunningURLIsEmpty) - } - err := w.SetWebsocketURL(s.RunningURL, false, false) - if err != nil { - return fmt.Errorf("%s %w", w.exchangeName, err) - } + w.connector = s.Connector + w.Subscriber = s.Subscriber + w.Unsubscriber = s.Unsubscriber + w.GenerateSubs = s.GenerateSubscriptions - if s.RunningURLAuth != "" { - err = w.SetWebsocketURL(s.RunningURLAuth, true, false) + err := w.SetWebsocketURL(s.RunningURL, false, false) if err != nil { return fmt.Errorf("%s %w", w.exchangeName, err) } + + if s.RunningURLAuth != "" { + err = w.SetWebsocketURL(s.RunningURLAuth, true, false) + if err != nil { + return fmt.Errorf("%s %w", w.exchangeName, err) + } + } + } + + w.connectionMonitorDelay = s.ExchangeConfig.ConnectionMonitorDelay + if w.connectionMonitorDelay <= 0 { + w.connectionMonitorDelay = config.DefaultConnectionMonitorDelay } if s.ExchangeConfig.WebsocketTrafficTimeout < time.Second { @@ -212,8 +230,14 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { // functions are defined per connection. Else we use the global connector // and supporting functions for backwards compatibility. if w.connector == nil { - if c.Handler == nil { - return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset) + if c.URL == "" { + return fmt.Errorf("%w: %w", errConnSetup, errDefaultURLIsEmpty) + } + if c.Connector == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) + } + if c.GenerateSubscriptions == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriptionsGeneratorUnset) } if c.Subscriber == nil { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriberUnset) @@ -221,11 +245,8 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { if c.Unsubscriber == nil && w.features.Unsubscribe { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketUnsubscriberUnset) } - if c.GenerateSubscriptions == nil { - return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriptionsGeneratorUnset) - } - if c.Connector == nil { - return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) + if c.Handler == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset) } w.ConnectionManager = append(w.ConnectionManager, ConnectionCandidate{ Details: c, @@ -287,13 +308,6 @@ func (w *Websocket) Connect() error { } w.subscriptions.Clear() - for _, details := range w.Connections { - if details.Subscriptions == nil { - return fmt.Errorf("%w: subscriptions", common.ErrNilPointer) - } - details.Subscriptions.Clear() - } - w.dataMonitor() w.trafficMonitor() w.setState(connectingState) @@ -550,28 +564,29 @@ func (w *Websocket) Shutdown() error { defer w.Orderbook.FlushBuffer() - for conn, details := range w.Connections { + // Shutdown managed connections + for conn := range w.Connections { if err := conn.Shutdown(); err != nil { return err } - details.Subscriptions.Clear() } - // Clean map of old connections clear(w.Connections) + // Flush any subscriptions from last connection across any managed connections + for x := range w.ConnectionManager { + w.ConnectionManager[x].Subscriptions.Clear() + } if w.Conn != nil { if err := w.Conn.Shutdown(); err != nil { return err } } - if w.AuthConn != nil { if err := w.AuthConn.Shutdown(); err != nil { return err } } - // flush any subscriptions from last connection if needed w.subscriptions.Clear() @@ -937,13 +952,16 @@ func (w *Websocket) GetName() string { // GetChannelDifference finds the difference between the subscribed channels // and the new subscription list when pairs are disabled or enabled. func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) { - if conn != nil { - return w.Connections[conn].Subscriptions.Diff(newSubs) + var subscriptionStore **subscription.Store + if candidate, ok := w.Connections[conn]; ok { + subscriptionStore = &candidate.Subscriptions + } else { + subscriptionStore = &w.subscriptions } - if w.subscriptions == nil { - w.subscriptions = subscription.NewStore() + if *subscriptionStore == nil { + *subscriptionStore = subscription.NewStore() } - return w.subscriptions.Diff(newSubs) + return (*subscriptionStore).Diff(newSubs) } // UnsubscribeChannels unsubscribes from a list of websocket channel @@ -952,20 +970,16 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L return nil // No channels to unsubscribe from is not an error } - if conn != nil { - store, ok := w.Connections[conn] - if !ok { - return errors.New("connection not found") - } - if store.Subscriptions == nil { + if candidate, ok := w.Connections[conn]; ok { + if candidate.Subscriptions == nil { return nil // No channels to unsubscribe from is not an error } for _, s := range channels { - if store.Subscriptions.Get(s) == nil { + if candidate.Subscriptions.Get(s) == nil { return fmt.Errorf("%w: %s", subscription.ErrNotFound, s) } } - return store.Details.Unsubscriber(context.TODO(), conn, channels) + return candidate.Details.Unsubscriber(context.TODO(), conn, channels) } if w.subscriptions == nil { @@ -1003,12 +1017,12 @@ func (w *Websocket) SubscribeToChannels(conn Connection, subs subscription.List) return err } - if conn != nil { - state, ok := w.Connections[conn] - if !ok { - return errors.New("connection details not found") - } - return state.Details.Subscriber(context.TODO(), conn, subs) + if candidate, ok := w.Connections[conn]; ok { + return candidate.Details.Subscriber(context.TODO(), conn, subs) + } + + if w.Subscriber == nil { + return fmt.Errorf("%w: Global Subscriber not set", common.ErrNilPointer) } if err := w.Subscriber(subs); err != nil { @@ -1046,32 +1060,23 @@ func (w *Websocket) AddSuccessfulSubscriptions(conn Connection, subs ...*subscri return fmt.Errorf("%w: AddSuccessfulSubscriptions called on nil Websocket", common.ErrNilPointer) } - if conn != nil { - state, ok := w.Connections[conn] - if !ok { - return errors.New("connection details not found") - } - var errs error - for _, s := range subs { - if err := s.SetState(subscription.SubscribedState); err != nil { - errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) - } - if err := state.Subscriptions.Add(s); err != nil { - errs = common.AppendError(errs, err) - } - } - return errs + var subscriptionStore **subscription.Store + if candidate, ok := w.Connections[conn]; ok { + subscriptionStore = &candidate.Subscriptions + } else { + subscriptionStore = &w.subscriptions } - if w.subscriptions == nil { - w.subscriptions = subscription.NewStore() + if *subscriptionStore == nil { + *subscriptionStore = subscription.NewStore() } + var errs error for _, s := range subs { if err := s.SetState(subscription.SubscribedState); err != nil { errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) } - if err := w.subscriptions.Add(s); err != nil { + if err := (*subscriptionStore).Add(s); err != nil { errs = common.AppendError(errs, err) } } @@ -1084,32 +1089,23 @@ func (w *Websocket) RemoveSubscriptions(conn Connection, subs ...*subscription.S return fmt.Errorf("%w: RemoveSubscriptions called on nil Websocket", common.ErrNilPointer) } - if conn != nil { - state, ok := w.Connections[conn] - if !ok { - return errors.New("connection details not found") - } - var errs error - for _, s := range subs { - if err := s.SetState(subscription.UnsubscribedState); err != nil { - errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) - } - if err := state.Subscriptions.Remove(s); err != nil { - errs = common.AppendError(errs, err) - } - } - return errs + var subscriptionStore *subscription.Store + if candidate, ok := w.Connections[conn]; ok { + subscriptionStore = candidate.Subscriptions + } else { + subscriptionStore = w.subscriptions } - if w.subscriptions == nil { + if subscriptionStore == nil { return fmt.Errorf("%w: RemoveSubscriptions called on uninitialised Websocket", common.ErrNilPointer) } + var errs error for _, s := range subs { if err := s.SetState(subscription.UnsubscribedState); err != nil { errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) } - if err := w.subscriptions.Remove(s); err != nil { + if err := subscriptionStore.Remove(s); err != nil { errs = common.AppendError(errs, err) } } @@ -1120,7 +1116,19 @@ func (w *Websocket) RemoveSubscriptions(conn Connection, subs ...*subscription.S // returns nil if no subscription is at that key or the key is nil // Keys can implement subscription.MatchableKey in order to provide custom matching logic func (w *Websocket) GetSubscription(key any) *subscription.Subscription { - if w == nil || w.subscriptions == nil || key == nil { + if w == nil || key == nil { + return nil + } + for _, c := range w.Connections { + if c.Subscriptions == nil { + continue + } + sub := c.Subscriptions.Get(key) + if sub != nil { + return sub + } + } + if w.subscriptions == nil { return nil } return w.subscriptions.Get(key) @@ -1128,10 +1136,19 @@ func (w *Websocket) GetSubscription(key any) *subscription.Subscription { // GetSubscriptions returns a new slice of the subscriptions func (w *Websocket) GetSubscriptions() subscription.List { - if w == nil || w.subscriptions == nil { + if w == nil { return nil } - return w.subscriptions.List() + var subs subscription.List + for _, c := range w.Connections { + if c.Subscriptions != nil { + subs = append(subs, c.Subscriptions.List()...) + } + } + if w.subscriptions != nil { + subs = append(subs, w.subscriptions.List()...) + } + return subs } // SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner @@ -1170,39 +1187,17 @@ func checkWebsocketURL(s string) error { // checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists // The subscription state is not considered when counting existing subscriptions func (w *Websocket) checkSubscriptions(conn Connection, subs subscription.List) error { - if conn != nil { - state, ok := w.Connections[conn] - if !ok { - return errors.New("connection ddetails not found") - } - - if state.Subscriptions == nil { - return fmt.Errorf("%w: Websocket.subscriptions", common.ErrNilPointer) - } - - existing := state.Subscriptions.Len() - if w.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > w.MaxSubscriptionsPerConnection { - return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs", - errSubscriptionsExceedsLimit, - existing, - len(subs), - w.MaxSubscriptionsPerConnection) - } - - for _, s := range subs { - if found := state.Subscriptions.Get(s); found != nil { - return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s) - } - } - - return nil + var subscriptionStore *subscription.Store + if candidate, ok := w.Connections[conn]; ok { + subscriptionStore = candidate.Subscriptions + } else { + subscriptionStore = w.subscriptions } - - if w.subscriptions == nil { + if subscriptionStore == nil { return fmt.Errorf("%w: Websocket.subscriptions", common.ErrNilPointer) } - existing := w.subscriptions.Len() + existing := subscriptionStore.Len() if w.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > w.MaxSubscriptionsPerConnection { return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs", errSubscriptionsExceedsLimit, @@ -1212,7 +1207,7 @@ func (w *Websocket) checkSubscriptions(conn Connection, subs subscription.List) } for _, s := range subs { - if found := w.subscriptions.Get(s); found != nil { + if found := subscriptionStore.Get(s); found != nil { return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s) } } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 026f3eb2396..04045097cd2 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -117,51 +117,68 @@ func TestSetup(t *testing.T) { t.Parallel() var w *Websocket err := w.Setup(nil) - assert.ErrorIs(t, err, errWebsocketIsNil, "Setup should error correctly") + assert.ErrorIs(t, err, errWebsocketIsNil) w = &Websocket{DataHandler: make(chan interface{})} err = w.Setup(nil) - assert.ErrorIs(t, err, errWebsocketSetupIsNil, "Setup should error correctly") + assert.ErrorIs(t, err, errWebsocketSetupIsNil) websocketSetup := &WebsocketSetup{} err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errExchangeConfigIsNil, "Setup should error correctly") + assert.ErrorIs(t, err, errExchangeConfigIsNil) websocketSetup.ExchangeConfig = &config.Exchange{} err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "Setup should error correctly") + assert.ErrorIs(t, err, errExchangeConfigNameEmpty) websocketSetup.ExchangeConfig.Name = "testname" err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errWebsocketFeaturesIsUnset, "Setup should error correctly") + assert.ErrorIs(t, err, errWebsocketFeaturesIsUnset) websocketSetup.Features = &protocol.Features{} err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errConfigFeaturesIsNil, "Setup should error correctly") + assert.ErrorIs(t, err, errConfigFeaturesIsNil) websocketSetup.ExchangeConfig.Features = &config.FeaturesConfig{} + websocketSetup.Subscriber = func(subscription.List) error { return nil } // kicks off the setup + err = w.Setup(websocketSetup) + assert.ErrorIs(t, err, errWebsocketConnectorUnset) + websocketSetup.Subscriber = nil + + websocketSetup.Connector = func() error { return nil } + err = w.Setup(websocketSetup) + assert.ErrorIs(t, err, errWebsocketSubscriberUnset) + + websocketSetup.Subscriber = func(subscription.List) error { return nil } + w.features.Unsubscribe = true + err = w.Setup(websocketSetup) + assert.ErrorIs(t, err, errWebsocketUnsubscriberUnset) + + websocketSetup.Unsubscriber = func(subscription.List) error { return nil } + err = w.Setup(websocketSetup) + assert.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) websocketSetup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errDefaultURLIsEmpty, "Setup should error correctly") + assert.ErrorIs(t, err, errDefaultURLIsEmpty) websocketSetup.DefaultURL = "test" err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errRunningURLIsEmpty, "Setup should error correctly") + assert.ErrorIs(t, err, errRunningURLIsEmpty) websocketSetup.RunningURL = "http://www.google.com" err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") + assert.ErrorIs(t, err, errInvalidWebsocketURL) websocketSetup.RunningURL = "wss://www.google.com" websocketSetup.RunningURLAuth = "http://www.google.com" err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") + assert.ErrorIs(t, err, errInvalidWebsocketURL) websocketSetup.RunningURLAuth = "wss://www.google.com" err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errInvalidTrafficTimeout, "Setup should error correctly") + assert.ErrorIs(t, err, errInvalidTrafficTimeout) websocketSetup.ExchangeConfig.WebsocketTrafficTimeout = time.Minute err = w.Setup(websocketSetup) @@ -497,12 +514,24 @@ func currySimpleSub(w *Websocket) func(subscription.List) error { } } +func currySimpleSubConn(w *Websocket) func(context.Context, Connection, subscription.List) error { + return func(_ context.Context, conn Connection, subs subscription.List) error { + return w.AddSuccessfulSubscriptions(conn, subs...) + } +} + func currySimpleUnsub(w *Websocket) func(subscription.List) error { return func(unsubs subscription.List) error { return w.RemoveSubscriptions(nil, unsubs...) } } +func currySimpleUnsubConn(w *Websocket) func(context.Context, Connection, subscription.List) error { + return func(_ context.Context, conn Connection, unsubs subscription.List) error { + return w.RemoveSubscriptions(conn, unsubs...) + } +} + // TestSubscribe logic test func TestSubscribeUnsubscribe(t *testing.T) { t.Parallel() @@ -545,6 +574,75 @@ func TestSubscribeUnsubscribe(t *testing.T) { err = ws.SubscribeToChannels(nil, subscription.List{nil}) assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription") + + multi := NewWebsocket() + set := *defaultSetup + // Values below are now not neccessary as this will be set per connection + // candidate in SetupNewConnection. + set.Connector = nil + set.Subscriber = nil + set.Unsubscriber = nil + set.GenerateSubscriptions = nil + set.DefaultURL = "" + set.RunningURL = "" + assert.NoError(t, multi.Setup(&set)) + + amazingCandidate := &ConnectionSetup{ + URL: "AMAZING", + Connector: func(ctx context.Context, c Connection) error { return nil }, + GenerateSubscriptions: ws.GenerateSubs, + Subscriber: func(ctx context.Context, c Connection, s subscription.List) error { + return currySimpleSubConn(multi)(ctx, c, s) + }, + Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error { + return currySimpleUnsubConn(multi)(ctx, c, s) + }, + Handler: func(ctx context.Context, b []byte) error { return nil }, + } + require.NoError(t, multi.SetupNewConnection(amazingCandidate)) + + amazingConn := multi.getConnectionFromSetup(amazingCandidate) + multi.Connections = map[Connection]*ConnectionCandidate{ + amazingConn: &multi.ConnectionManager[0], + } + + subs, err = amazingCandidate.GenerateSubscriptions() + require.NoError(t, err, "Generating test subscriptions should not error") + assert.NoError(t, new(Websocket).UnsubscribeChannels(nil, subs), "Should not error when w.subscriptions is nil") + assert.NoError(t, new(Websocket).UnsubscribeChannels(amazingConn, subs), "Should not error when w.subscriptions is nil") + assert.NoError(t, multi.UnsubscribeChannels(amazingConn, nil), "Unsubscribing from nil should not error") + assert.ErrorIs(t, multi.UnsubscribeChannels(amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") + assert.Nil(t, multi.GetSubscription(42), "GetSubscription on empty internal map should return") + + assert.ErrorIs(t, multi.SubscribeToChannels(nil, subs), common.ErrNilPointer, "If no connection is set, Subscribe should error") + + assert.NoError(t, multi.SubscribeToChannels(amazingConn, subs), "Basic Subscribing should not error") + assert.Len(t, multi.GetSubscriptions(), 4, "Should have 4 subscriptions") + bySub = multi.GetSubscription(subscription.Subscription{Channel: "TestSub"}) + if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") { + assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel") + assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer") + } + if assert.NotNil(t, multi.GetSubscription("purple"), "GetSubscription by string key should find a channel") { + assert.Equal(t, "TestSub2", multi.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel") + } + if assert.NotNil(t, multi.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") { + assert.Equal(t, "TestSub3", multi.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer a copy of the right channel") + } + if assert.NotNil(t, multi.GetSubscription(42), "GetSubscription by int key should find a channel") { + assert.Equal(t, "TestSub4", multi.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer a copy of the right channel") + } + assert.Nil(t, multi.GetSubscription(nil), "GetSubscription by nil should return nil") + assert.Nil(t, multi.GetSubscription(45), "GetSubscription by invalid key should return nil") + assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") + assert.NoError(t, multi.SubscribeToChannels(amazingConn, nil), "Subscribe to an nil List should not error") + assert.NoError(t, multi.UnsubscribeChannels(amazingConn, subs), "Unsubscribing should not error") + + amazingCandidate.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason } + assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") + + err = multi.SubscribeToChannels(amazingConn, subscription.List{nil}) + assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription") } // TestResubscribe tests Resubscribing to existing subscriptions @@ -996,6 +1094,35 @@ func TestGetChannelDifference(t *testing.T) { subs, unsubs = w.GetChannelDifference(nil, subscription.List{{Channel: subscription.TickerChannel}}) require.Equal(t, 1, len(subs), "Should get the correct number of subs") assert.Equal(t, 1, len(unsubs), "Should get the correct number of unsubs") + + w = &Websocket{} + sweetConn := &WebsocketConnection{} + subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}}) + require.Equal(t, 1, len(subs)) + require.Empty(t, unsubs, "Should get no unsubs") + + w.Connections = map[Connection]*ConnectionCandidate{ + sweetConn: {Details: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}, + } + + naughtyConn := &WebsocketConnection{} + subs, unsubs = w.GetChannelDifference(naughtyConn, subscription.List{{Channel: subscription.CandlesChannel}}) + require.Equal(t, 1, len(subs)) + require.Empty(t, unsubs, "Should get no unsubs") + + subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}}) + require.Equal(t, 1, len(subs)) + require.Empty(t, unsubs, "Should get no unsubs") + + w.Connections[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel}) + subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}}) + require.Empty(t, subs, "Should get no subs") + require.Empty(t, unsubs, "Should get no unsubs") + + subs, unsubs = w.GetChannelDifference(sweetConn, nil) + require.Empty(t, subs, "Should get no subs") + require.Equal(t, 1, len(unsubs)) + } // GenSubs defines a theoretical exchange with pair management @@ -1169,6 +1296,54 @@ func TestSetupNewConnection(t *testing.T) { err = web.SetupNewConnection(&ConnectionSetup{URL: "urlstring", Authenticated: true}) assert.NoError(t, err, "SetupNewConnection should not error") + + // Test connection candidates for multi connection tracking. + multi := NewWebsocket() + set := *defaultSetup + + // Values below are now not neccessary as this will be set per connection + // candidate in SetupNewConnection. + set.Connector = nil + set.Subscriber = nil + set.Unsubscriber = nil + set.GenerateSubscriptions = nil + set.DefaultURL = "" + set.RunningURL = "" + + require.NoError(t, multi.Setup(&set)) + + connSetup := &ConnectionSetup{} + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errDefaultURLIsEmpty) + + connSetup.URL = "urlstring" + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errWebsocketConnectorUnset) + + connSetup.Connector = func(context.Context, Connection) error { return nil } + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) + + connSetup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errWebsocketSubscriberUnset) + + connSetup.Subscriber = func(context.Context, Connection, subscription.List) error { return nil } + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errWebsocketUnsubscriberUnset) + + connSetup.Unsubscriber = func(context.Context, Connection, subscription.List) error { return nil } + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errWebsocketDataHandlerUnset) + + connSetup.Handler = func(context.Context, []byte) error { return nil } + err = multi.SetupNewConnection(connSetup) + require.NoError(t, err) + + require.Len(t, multi.ConnectionManager, 1) + + require.Nil(t, multi.AuthConn) + require.Nil(t, multi.Conn) } func TestWebsocketConnectionShutdown(t *testing.T) { diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index b88c55c5488..dc6dd20256e 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -29,21 +29,6 @@ const ( connectedState ) -// // ConnectionAssociation contains the connection details and subscriptions -// type ConnectionAssociation struct { -// Subscriptions *subscription.Store -// Details *ConnectionSetup -// } - -// ConnectionCandidate contains the connection setup details to be used when -// attempting a new connection. It also contains the subscriptions that are -// associated with the specifc connection. -type ConnectionCandidate struct { - Details *ConnectionSetup - Subscriptions *subscription.Store - Connection Connection // TODO: Upgrade to slice of connections. -} - // Websocket defines a return type for websocket connections via the interface // wrapper for routine processing type Websocket struct { From c7d2b620a1f874bfeae1ffea6c5fa390e4bb1e74 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 16 Jul 2024 14:18:56 +1000 Subject: [PATCH 009/138] glorious:nits + proxy handling --- exchanges/gateio/gateio_wrapper.go | 9 +- .../gateio/gateio_ws_delivery_futures.go | 2 +- exchanges/gateio/gateio_ws_futures.go | 2 +- exchanges/gateio/gateio_ws_option.go | 2 +- exchanges/stream/stream_types.go | 4 +- exchanges/stream/websocket.go | 182 ++++++++++-------- exchanges/stream/websocket_connection.go | 6 +- exchanges/stream/websocket_test.go | 38 ++-- exchanges/stream/websocket_types.go | 11 +- 9 files changed, 143 insertions(+), 113 deletions(-) diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 1c996afe219..a098e5a473c 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -193,10 +193,11 @@ func (g *Gateio) Setup(exch *config.Exchange) error { } err = g.Websocket.Setup(&stream.WebsocketSetup{ - ExchangeConfig: exch, - Features: &g.Features.Supports.WebsocketCapabilities, - FillsFeed: g.Features.Enabled.FillsFeed, - TradeFeed: g.Features.Enabled.TradeFeed, + ExchangeConfig: exch, + Features: &g.Features.Supports.WebsocketCapabilities, + FillsFeed: g.Features.Enabled.FillsFeed, + TradeFeed: g.Features.Enabled.TradeFeed, + UseMultiConnectionManagement: true, }) if err != nil { return err diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index da0dc2130b8..0f70c82d32f 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -54,7 +54,7 @@ func (g *Gateio) WsDeliveryFuturesConnect(ctx context.Context, conn stream.Conne } pingMessage, err := json.Marshal(WsInput{ ID: conn.GenerateMessageID(false), - Time: time.Now().Unix(), // TODO: func for dynamic time + Time: time.Now().Unix(), // TODO: Func for dynamic time as this will be the same time for every ping message. Channel: futuresPingChannel, }) if err != nil { diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 561a1e73c6d..f6764b3bcaf 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -77,7 +77,7 @@ func (g *Gateio) WsFuturesConnect(ctx context.Context, conn stream.Connection) e } pingMessage, err := json.Marshal(WsInput{ ID: conn.GenerateMessageID(false), - Time: time.Now().Unix(), // TODO: This should be a timer function. + Time: time.Now().Unix(), // TODO: Func for dynamic time as this will be the same time for every ping message. Channel: futuresPingChannel, }) if err != nil { diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index 9524cd6d34d..5c5c7d514fb 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -87,7 +87,7 @@ func (g *Gateio) WsOptionsConnect(ctx context.Context, conn stream.Connection) e } pingMessage, err := json.Marshal(WsInput{ ID: conn.GenerateMessageID(false), - Time: time.Now().Unix(), + Time: time.Now().Unix(), // TODO: Func for dynamic time as this will be the same time for every ping message. Channel: optionsPingChannel, }) if err != nil { diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 0cfa66a06ca..65e6235fa0c 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -69,8 +69,8 @@ type ConnectionSetup struct { // attempting a new connection. It also contains the subscriptions that are // associated with the specifc connection. type ConnectionCandidate struct { - // Details contains the connection setup details - Details *ConnectionSetup + // Setup contains the connection setup details + Setup *ConnectionSetup // Subscriptions contains the subscriptions that are associated with the // specific connection(s) Subscriptions *subscription.Store diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index e569d53fd77..4e087f0bc9c 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -65,6 +65,8 @@ var ( errAlreadyReconnecting = errors.New("websocket in the process of reconnection") errConnSetup = errors.New("error in connection setup") errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") + errConnectionCandidateDuplication = errors.New("connection candidate duplication") + errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") ) var ( @@ -90,7 +92,7 @@ func NewWebsocket() *Websocket { subscriptions: subscription.NewStore(), features: &protocol.Features{}, Orderbook: buffer.Orderbook{}, - Connections: make(map[Connection]*ConnectionCandidate), + connections: make(map[Connection]*ConnectionCandidate), } } @@ -131,10 +133,11 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { } w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket) - // If any fields here are set, assume that the previous global connector - // pattern is being used. - // TODO: Shift everything to connection setup when all exchanges are updated. - if s.Connector != nil || s.Subscriber != nil || s.Unsubscriber != nil || s.GenerateSubscriptions != nil || s.DefaultURL != "" || s.RunningURL != "" { + w.useMultiConnectionManagement = s.UseMultiConnectionManagement + + if !w.useMultiConnectionManagement { + // TODO: Remove this block when all exchanges are updated and backwards + // compatibility is no longer required. if s.Connector == nil { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) } @@ -226,10 +229,9 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { c.ConnectionLevelReporter = globalReporter } - // If connector is nil, we assume that the connection and supporting - // functions are defined per connection. Else we use the global connector - // and supporting functions for backwards compatibility. - if w.connector == nil { + if w.useMultiConnectionManagement { + // The connection and supporting functions are defined per connection + // and the connection candidate is stored in the connection manager. if c.URL == "" { return fmt.Errorf("%w: %w", errConnSetup, errDefaultURLIsEmpty) } @@ -248,8 +250,15 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { if c.Handler == nil { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset) } - w.ConnectionManager = append(w.ConnectionManager, ConnectionCandidate{ - Details: c, + + for x := range w.connectionManager { + if w.connectionManager[x].Setup.URL == c.URL { + return fmt.Errorf("%w: %w", errConnSetup, errConnectionCandidateDuplication) + } + } + + w.connectionManager = append(w.connectionManager, ConnectionCandidate{ + Setup: c, Subscriptions: subscription.NewStore(), }) return nil @@ -339,22 +348,23 @@ func (w *Websocket) Connect() error { return nil } - // hasStableConnection is used to determine if the websocket has a stable - // connection. If it does not, the websocket will be set to disconnected. - hasStableConnection := false - defer w.setStateFromHasStableConnection(&hasStableConnection) - - if len(w.ConnectionManager) == 0 { + if len(w.connectionManager) == 0 { + w.setState(disconnectedState) return fmt.Errorf("cannot connect: %w", errNoPendingConnections) } + // Assume connected state and if there are any issues below can call Shutdown + w.setState(connectedState) + + var multiConnectError error // TODO: Implement concurrency below. - for i := range w.ConnectionManager { - if w.ConnectionManager[i].Details.GenerateSubscriptions == nil { - return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.ConnectionManager[i].Details.URL, errWebsocketSubscriptionsGeneratorUnset) + for i := range w.connectionManager { + if w.connectionManager[i].Setup.GenerateSubscriptions == nil { + multiConnectError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketSubscriptionsGeneratorUnset) + break } - subs, err := w.ConnectionManager[i].Details.GenerateSubscriptions() // regenerate state on new connection + subs, err := w.connectionManager[i].Setup.GenerateSubscriptions() // regenerate state on new connection if err != nil { if errors.Is(err, asset.ErrNotEnabled) { if w.verbose { @@ -362,7 +372,8 @@ func (w *Websocket) Connect() error { } continue // Non-fatal error, we can continue to the next connection } - return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) + multiConnectError = fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) + break } if len(subs) == 0 { @@ -373,38 +384,53 @@ func (w *Websocket) Connect() error { continue } - if w.ConnectionManager[i].Details.Connector == nil { - return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.ConnectionManager[i].Details.URL, errNoConnectFunc) + if w.connectionManager[i].Setup.Connector == nil { + multiConnectError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errNoConnectFunc) + break } - if w.ConnectionManager[i].Details.Handler == nil { - return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.ConnectionManager[i].Details.URL, errWebsocketDataHandlerUnset) + if w.connectionManager[i].Setup.Handler == nil { + multiConnectError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketDataHandlerUnset) + break } - if w.ConnectionManager[i].Details.Subscriber == nil { - return fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.ConnectionManager[i].Details.URL, errWebsocketSubscriberUnset) + if w.connectionManager[i].Setup.Subscriber == nil { + multiConnectError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketSubscriberUnset) + break } // TODO: Add window for max subscriptions per connection, to spawn new connections if needed. - conn := w.getConnectionFromSetup(w.ConnectionManager[i].Details) + conn := w.getConnectionFromSetup(w.connectionManager[i].Setup) - err = w.ConnectionManager[i].Details.Connector(context.TODO(), conn) + err = w.connectionManager[i].Setup.Connector(context.TODO(), conn) if err != nil { - return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) + multiConnectError = fmt.Errorf("%v Error connecting %w", w.exchangeName, err) + break } - hasStableConnection = true - w.Wg.Add(1) - go w.Reader(context.TODO(), conn, w.ConnectionManager[i].Details.Handler) + go w.Reader(context.TODO(), conn, w.connectionManager[i].Setup.Handler) - w.Connections[conn] = &w.ConnectionManager[i] + w.connections[conn] = &w.connectionManager[i] - err = w.ConnectionManager[i].Details.Subscriber(context.TODO(), conn, subs) + err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs) if err != nil { - return fmt.Errorf("%v Error subscribing %w", w.exchangeName, err) + multiConnectError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err) + break } - w.ConnectionManager[i].Connection = conn + w.connectionManager[i].Connection = conn + } + + if multiConnectError != nil { + for conn, candidate := range w.connections { + if err := conn.Shutdown(); err != nil { + log.Errorln(log.WebsocketMgr, err) + } + candidate.Subscriptions.Clear() + } + clear(w.connections) + w.setState(disconnectedState) + return multiConnectError } if !w.IsConnectionMonitorRunning() { @@ -417,14 +443,6 @@ func (w *Websocket) Connect() error { return nil } -func (w *Websocket) setStateFromHasStableConnection(hasStableConnection *bool) { - if *hasStableConnection { - w.setState(connectedState) - } else { - w.setState(disconnectedState) - } -} - // Disable disables the exchange websocket protocol // Note that connectionMonitor will be responsible for shutting down the websocket after disabling func (w *Websocket) Disable() error { @@ -565,16 +583,16 @@ func (w *Websocket) Shutdown() error { defer w.Orderbook.FlushBuffer() // Shutdown managed connections - for conn := range w.Connections { + for conn := range w.connections { if err := conn.Shutdown(); err != nil { return err } } // Clean map of old connections - clear(w.Connections) + clear(w.connections) // Flush any subscriptions from last connection across any managed connections - for x := range w.ConnectionManager { - w.ConnectionManager[x].Subscriptions.Clear() + for x := range w.connectionManager { + w.connectionManager[x].Subscriptions.Clear() } if w.Conn != nil { @@ -630,26 +648,26 @@ func (w *Websocket) FlushChannels() error { } return w.SubscribeToChannels(nil, subs) } - for x := range w.ConnectionManager { - if w.ConnectionManager[x].Details.GenerateSubscriptions == nil { + for x := range w.connectionManager { + if w.connectionManager[x].Setup.GenerateSubscriptions == nil { continue } - newsubs, err := w.ConnectionManager[x].Details.GenerateSubscriptions() + newsubs, err := w.connectionManager[x].Setup.GenerateSubscriptions() if err != nil { if errors.Is(err, asset.ErrNotEnabled) { continue } return err } - subs, unsubs := w.GetChannelDifference(w.ConnectionManager[x].Connection, newsubs) + subs, unsubs := w.GetChannelDifference(w.connectionManager[x].Connection, newsubs) if len(unsubs) != 0 && w.features.Unsubscribe { - err = w.UnsubscribeChannels(w.ConnectionManager[x].Connection, unsubs) + err = w.UnsubscribeChannels(w.connectionManager[x].Connection, unsubs) if err != nil { return err } } if len(subs) != 0 { - err = w.SubscribeToChannels(w.ConnectionManager[x].Connection, subs) + err = w.SubscribeToChannels(w.connectionManager[x].Connection, subs) if err != nil { return err } @@ -677,11 +695,11 @@ func (w *Websocket) FlushChannels() error { return nil } - for x := range w.ConnectionManager { - if w.ConnectionManager[x].Details.GenerateSubscriptions == nil { + for x := range w.connectionManager { + if w.connectionManager[x].Setup.GenerateSubscriptions == nil { continue } - newsubs, err := w.ConnectionManager[x].Details.GenerateSubscriptions() + newsubs, err := w.connectionManager[x].Setup.GenerateSubscriptions() if err != nil { if errors.Is(err, asset.ErrNotEnabled) { continue @@ -690,8 +708,8 @@ func (w *Websocket) FlushChannels() error { } if len(newsubs) != 0 { // Purge subscription list as there will be conflicts - w.Connections[w.ConnectionManager[x].Connection].Subscriptions.Clear() - err = w.SubscribeToChannels(w.ConnectionManager[x].Connection, newsubs) + w.connections[w.connectionManager[x].Connection].Subscriptions.Clear() + err = w.SubscribeToChannels(w.connectionManager[x].Connection, newsubs) if err != nil { return err } @@ -838,6 +856,10 @@ func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool { // SetWebsocketURL sets websocket URL and can refresh underlying connections func (w *Websocket) SetWebsocketURL(url string, auth, reconnect bool) error { + if w.useMultiConnectionManagement { + // TODO: Enable multi connection management to change URL + return fmt.Errorf("%s: %w", w.exchangeName, errCannotChangeConnectionURL) + } defaultVals := url == "" || url == config.WebsocketURLNonDefaultMessage if auth { if defaultVals { @@ -851,10 +873,7 @@ func (w *Websocket) SetWebsocketURL(url string, auth, reconnect bool) error { w.runningURLAuth = url if w.verbose { - log.Debugf(log.WebsocketMgr, - "%s websocket: setting authenticated websocket URL: %s\n", - w.exchangeName, - url) + log.Debugf(log.WebsocketMgr, "%s websocket: setting authenticated websocket URL: %s\n", w.exchangeName, url) } if w.AuthConn != nil { @@ -871,10 +890,7 @@ func (w *Websocket) SetWebsocketURL(url string, auth, reconnect bool) error { w.runningURL = url if w.verbose { - log.Debugf(log.WebsocketMgr, - "%s websocket: setting unauthenticated websocket URL: %s\n", - w.exchangeName, - url) + log.Debugf(log.WebsocketMgr, "%s websocket: setting unauthenticated websocket URL: %s\n", w.exchangeName, url) } if w.Conn != nil { @@ -883,10 +899,7 @@ func (w *Websocket) SetWebsocketURL(url string, auth, reconnect bool) error { } if w.IsConnected() && reconnect { - log.Debugf(log.WebsocketMgr, - "%s websocket: flushing websocket connection to %s\n", - w.exchangeName, - url) + log.Debugf(log.WebsocketMgr, "%s websocket: flushing websocket connection to %s\n", w.exchangeName, url) return w.Shutdown() } return nil @@ -917,6 +930,9 @@ func (w *Websocket) SetProxyAddress(proxyAddr string) error { log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", w.exchangeName) } + for conn := range w.connections { + conn.SetProxy(proxyAddr) + } if w.Conn != nil { w.Conn.SetProxy(proxyAddr) } @@ -953,7 +969,7 @@ func (w *Websocket) GetName() string { // and the new subscription list when pairs are disabled or enabled. func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) { var subscriptionStore **subscription.Store - if candidate, ok := w.Connections[conn]; ok { + if candidate, ok := w.connections[conn]; ok { subscriptionStore = &candidate.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -970,7 +986,7 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L return nil // No channels to unsubscribe from is not an error } - if candidate, ok := w.Connections[conn]; ok { + if candidate, ok := w.connections[conn]; ok { if candidate.Subscriptions == nil { return nil // No channels to unsubscribe from is not an error } @@ -979,7 +995,7 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L return fmt.Errorf("%w: %s", subscription.ErrNotFound, s) } } - return candidate.Details.Unsubscriber(context.TODO(), conn, channels) + return candidate.Setup.Unsubscriber(context.TODO(), conn, channels) } if w.subscriptions == nil { @@ -1017,8 +1033,8 @@ func (w *Websocket) SubscribeToChannels(conn Connection, subs subscription.List) return err } - if candidate, ok := w.Connections[conn]; ok { - return candidate.Details.Subscriber(context.TODO(), conn, subs) + if candidate, ok := w.connections[conn]; ok { + return candidate.Setup.Subscriber(context.TODO(), conn, subs) } if w.Subscriber == nil { @@ -1061,7 +1077,7 @@ func (w *Websocket) AddSuccessfulSubscriptions(conn Connection, subs ...*subscri } var subscriptionStore **subscription.Store - if candidate, ok := w.Connections[conn]; ok { + if candidate, ok := w.connections[conn]; ok { subscriptionStore = &candidate.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -1090,7 +1106,7 @@ func (w *Websocket) RemoveSubscriptions(conn Connection, subs ...*subscription.S } var subscriptionStore *subscription.Store - if candidate, ok := w.Connections[conn]; ok { + if candidate, ok := w.connections[conn]; ok { subscriptionStore = candidate.Subscriptions } else { subscriptionStore = w.subscriptions @@ -1119,7 +1135,7 @@ func (w *Websocket) GetSubscription(key any) *subscription.Subscription { if w == nil || key == nil { return nil } - for _, c := range w.Connections { + for _, c := range w.connections { if c.Subscriptions == nil { continue } @@ -1140,7 +1156,7 @@ func (w *Websocket) GetSubscriptions() subscription.List { return nil } var subs subscription.List - for _, c := range w.Connections { + for _, c := range w.connections { if c.Subscriptions != nil { subs = append(subs, c.Subscriptions.List()...) } @@ -1188,7 +1204,7 @@ func checkWebsocketURL(s string) error { // The subscription state is not considered when counting existing subscriptions func (w *Websocket) checkSubscriptions(conn Connection, subs subscription.List) error { var subscriptionStore *subscription.Store - if candidate, ok := w.Connections[conn]; ok { + if candidate, ok := w.connections[conn]; ok { subscriptionStore = candidate.Subscriptions } else { subscriptionStore = w.subscriptions diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 2871b1d1d23..2157dd9a01a 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -70,11 +70,12 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header w.Connection, conStatus, err = dialer.Dial(w.URL, headers) if err != nil { if conStatus != nil { + _ = conStatus.Body.Close() return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", w.ExchangeName, w.URL, conStatus, conStatus.StatusCode, err) } return fmt.Errorf("%s websocket connection: %v Error: %w", w.ExchangeName, w.URL, err) } - defer conStatus.Body.Close() // TODO: Close on error above. This is a potential resource leak. + _ = conStatus.Body.Close() if w.Verbose { log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", w.ExchangeName, w.URL) @@ -102,11 +103,12 @@ func (w *WebsocketConnection) DialContext(ctx context.Context, dialer *websocket w.Connection, conStatus, err = dialer.DialContext(ctx, w.URL, headers) if err != nil { if conStatus != nil { + _ = conStatus.Body.Close() return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", w.ExchangeName, w.URL, conStatus, conStatus.StatusCode, err) } return fmt.Errorf("%s websocket connection: %v Error: %w", w.ExchangeName, w.URL, err) } - defer conStatus.Body.Close() // TODO: Close on error above. This is a potential resource leak. + _ = conStatus.Body.Close() if w.Verbose { log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", w.ExchangeName, w.URL) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 04045097cd2..d55223e1cef 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -358,59 +358,58 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Connect() assert.ErrorIs(t, err, errNoPendingConnections, "Connect should error correctly") - ws.ConnectionManager = []ConnectionCandidate{{Details: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}} + ws.connectionManager = []ConnectionCandidate{{Setup: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}} err = ws.Connect() require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) - ws.ConnectionManager[0].Details.GenerateSubscriptions = func() (subscription.List, error) { + ws.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.ConnectionManager[0].Details.GenerateSubscriptions = func() (subscription.List, error) { + ws.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return subscription.List{{}}, nil } err = ws.Connect() require.ErrorIs(t, err, errNoConnectFunc) - ws.ConnectionManager[0].Details.Connector = func(context.Context, Connection) error { + ws.connectionManager[0].Setup.Connector = func(context.Context, Connection) error { return errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errWebsocketDataHandlerUnset) - ws.ConnectionManager[0].Details.Handler = func(context.Context, []byte) error { + ws.connectionManager[0].Setup.Handler = func(context.Context, []byte) error { return errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errWebsocketSubscriberUnset) - ws.ConnectionManager[0].Details.Subscriber = func(context.Context, Connection, subscription.List) error { + ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.ConnectionManager[0].Details.Connector = func(context.Context, Connection) error { + ws.connectionManager[0].Setup.Connector = func(context.Context, Connection) error { return nil } err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.ConnectionManager[0].Details.Handler = func(context.Context, []byte) error { + ws.connectionManager[0].Setup.Handler = func(context.Context, []byte) error { return nil } - require.NoError(t, ws.Shutdown()) err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.ConnectionManager[0].Details.Subscriber = func(context.Context, Connection, subscription.List) error { + ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error { return nil } - require.NoError(t, ws.Shutdown()) err = ws.Connect() require.NoError(t, err) + require.NoError(t, ws.Shutdown()) } func TestWebsocket(t *testing.T) { @@ -579,6 +578,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { set := *defaultSetup // Values below are now not neccessary as this will be set per connection // candidate in SetupNewConnection. + set.UseMultiConnectionManagement = true set.Connector = nil set.Subscriber = nil set.Unsubscriber = nil @@ -602,8 +602,8 @@ func TestSubscribeUnsubscribe(t *testing.T) { require.NoError(t, multi.SetupNewConnection(amazingCandidate)) amazingConn := multi.getConnectionFromSetup(amazingCandidate) - multi.Connections = map[Connection]*ConnectionCandidate{ - amazingConn: &multi.ConnectionManager[0], + multi.connections = map[Connection]*ConnectionCandidate{ + amazingConn: &multi.connectionManager[0], } subs, err = amazingCandidate.GenerateSubscriptions() @@ -1101,8 +1101,8 @@ func TestGetChannelDifference(t *testing.T) { require.Equal(t, 1, len(subs)) require.Empty(t, unsubs, "Should get no unsubs") - w.Connections = map[Connection]*ConnectionCandidate{ - sweetConn: {Details: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}, + w.connections = map[Connection]*ConnectionCandidate{ + sweetConn: {Setup: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}, } naughtyConn := &WebsocketConnection{} @@ -1114,7 +1114,7 @@ func TestGetChannelDifference(t *testing.T) { require.Equal(t, 1, len(subs)) require.Empty(t, unsubs, "Should get no unsubs") - w.Connections[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel}) + w.connections[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel}) subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}}) require.Empty(t, subs, "Should get no subs") require.Empty(t, unsubs, "Should get no unsubs") @@ -1303,6 +1303,7 @@ func TestSetupNewConnection(t *testing.T) { // Values below are now not neccessary as this will be set per connection // candidate in SetupNewConnection. + set.UseMultiConnectionManagement = true set.Connector = nil set.Subscriber = nil set.Unsubscriber = nil @@ -1340,10 +1341,13 @@ func TestSetupNewConnection(t *testing.T) { err = multi.SetupNewConnection(connSetup) require.NoError(t, err) - require.Len(t, multi.ConnectionManager, 1) + require.Len(t, multi.connectionManager, 1) require.Nil(t, multi.AuthConn) require.Nil(t, multi.Conn) + + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errConnectionCandidateDuplication) } func TestWebsocketConnectionShutdown(t *testing.T) { diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index dc6dd20256e..7c0c443192f 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -52,10 +52,10 @@ type Websocket struct { // ConnectionManager contains the connection candidates and the current // connections - ConnectionManager []ConnectionCandidate + connectionManager []ConnectionCandidate // Connections contains the current connections with their associated // connection candidates - Connections map[Connection]*ConnectionCandidate + connections map[Connection]*ConnectionCandidate subscriptions *subscription.Store @@ -66,6 +66,8 @@ type Websocket struct { // GenerateSubs function for exchange specific generating subscriptions from Features.Subscriptions, Pairs and Assets GenerateSubs func() (subscription.List, error) + useMultiConnectionManagement bool + DataHandler chan interface{} ToRoutine chan interface{} @@ -119,6 +121,11 @@ type WebsocketSetup struct { // Local orderbook buffer config values OrderbookBufferConfig buffer.Config + // UseMultiConnectionManagement allows this connection to be managed by the + // connection manager. If false, this will default to the global fields + // provided in this struct. + UseMultiConnectionManagement bool + TradeFeed bool // Fill data config values From fc281eeee000b44a29a6d30241b4d246d0a31267 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 16 Jul 2024 14:20:52 +1000 Subject: [PATCH 010/138] Spelling --- exchanges/stream/stream_types.go | 2 +- exchanges/stream/websocket_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 65e6235fa0c..f757ff92909 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -67,7 +67,7 @@ type ConnectionSetup struct { // ConnectionCandidate contains the connection setup details to be used when // attempting a new connection. It also contains the subscriptions that are -// associated with the specifc connection. +// associated with the specific connection. type ConnectionCandidate struct { // Setup contains the connection setup details Setup *ConnectionSetup diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index d55223e1cef..b481bdf0cc8 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -576,7 +576,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { multi := NewWebsocket() set := *defaultSetup - // Values below are now not neccessary as this will be set per connection + // Values below are now not necessary as this will be set per connection // candidate in SetupNewConnection. set.UseMultiConnectionManagement = true set.Connector = nil @@ -1301,7 +1301,7 @@ func TestSetupNewConnection(t *testing.T) { multi := NewWebsocket() set := *defaultSetup - // Values below are now not neccessary as this will be set per connection + // Values below are now not necessary as this will be set per connection // candidate in SetupNewConnection. set.UseMultiConnectionManagement = true set.Connector = nil From eaa44ba00f94d9714df9b178b738d3fe2507cf6b Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 16 Jul 2024 14:29:26 +1000 Subject: [PATCH 011/138] linter: fixerino --- exchanges/stream/websocket_test.go | 9 +++++---- exchanges/stream/websocket_types.go | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index b481bdf0cc8..92edd545350 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -589,7 +589,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { amazingCandidate := &ConnectionSetup{ URL: "AMAZING", - Connector: func(ctx context.Context, c Connection) error { return nil }, + Connector: func(context.Context, Connection) error { return nil }, GenerateSubscriptions: ws.GenerateSubs, Subscriber: func(ctx context.Context, c Connection, s subscription.List) error { return currySimpleSubConn(multi)(ctx, c, s) @@ -597,7 +597,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error { return currySimpleUnsubConn(multi)(ctx, c, s) }, - Handler: func(ctx context.Context, b []byte) error { return nil }, + Handler: func(context.Context, []byte) error { return nil }, } require.NoError(t, multi.SetupNewConnection(amazingCandidate)) @@ -1114,7 +1114,9 @@ func TestGetChannelDifference(t *testing.T) { require.Equal(t, 1, len(subs)) require.Empty(t, unsubs, "Should get no unsubs") - w.connections[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel}) + err := w.connections[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel}) + require.NoError(t, err) + subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}}) require.Empty(t, subs, "Should get no subs") require.Empty(t, unsubs, "Should get no unsubs") @@ -1122,7 +1124,6 @@ func TestGetChannelDifference(t *testing.T) { subs, unsubs = w.GetChannelDifference(sweetConn, nil) require.Empty(t, subs, "Should get no subs") require.Equal(t, 1, len(unsubs)) - } // GenSubs defines a theoretical exchange with pair management diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 7c0c443192f..529798a36df 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -121,7 +121,7 @@ type WebsocketSetup struct { // Local orderbook buffer config values OrderbookBufferConfig buffer.Config - // UseMultiConnectionManagement allows this connection to be managed by the + // UseMultiConnectionManagement allows the connections to be managed by the // connection manager. If false, this will default to the global fields // provided in this struct. UseMultiConnectionManagement bool From a8debf995b428468b535ba46bf9d881b33e0fb3a Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 16 Jul 2024 14:47:12 +1000 Subject: [PATCH 012/138] instead of nil, dont do nil. --- docs/ADD_NEW_EXCHANGE.md | 2 +- exchanges/binance/binance_websocket.go | 6 +++--- exchanges/binanceus/binanceus_websocket.go | 2 +- exchanges/bitfinex/bitfinex_test.go | 16 ++++++++-------- exchanges/bitfinex/bitfinex_websocket.go | 8 ++++---- exchanges/bithumb/bithumb_websocket.go | 2 +- exchanges/bitmex/bitmex_websocket.go | 4 ++-- exchanges/bitstamp/bitstamp_websocket.go | 4 ++-- exchanges/btcmarkets/btcmarkets_websocket.go | 4 ++-- exchanges/btse/btse_websocket.go | 4 ++-- exchanges/coinbasepro/coinbasepro_websocket.go | 4 ++-- exchanges/coinut/coinut_websocket.go | 4 ++-- exchanges/gemini/gemini_websocket.go | 4 ++-- exchanges/hitbtc/hitbtc_websocket.go | 4 ++-- exchanges/huobi/huobi_websocket.go | 4 ++-- exchanges/kraken/kraken_websocket.go | 10 ++++++++-- exchanges/kucoin/kucoin_websocket.go | 4 ++-- exchanges/okcoin/okcoin_websocket.go | 12 ++++++------ exchanges/okx/okx_websocket.go | 12 ++++++------ exchanges/poloniex/poloniex_websocket.go | 4 ++-- exchanges/stream/websocket.go | 15 +++++++++++---- exchanges/stream/websocket_test.go | 14 +++++++------- 22 files changed, 78 insertions(+), 65 deletions(-) diff --git a/docs/ADD_NEW_EXCHANGE.md b/docs/ADD_NEW_EXCHANGE.md index 16699b914e2..9a4ea88288f 100644 --- a/docs/ADD_NEW_EXCHANGE.md +++ b/docs/ADD_NEW_EXCHANGE.md @@ -1077,7 +1077,7 @@ channels: continue } // When we have a successful unsubscription, we can alert our internal management system of the success. - f.Websocket.RemoveSubscriptions(nil, channelsToUnsubscribe[i]) + f.Websocket.RemoveSubscriptions(f.Websocket.Conn, channelsToUnsubscribe[i]) } if errs != nil { return errs diff --git a/exchanges/binance/binance_websocket.go b/exchanges/binance/binance_websocket.go index ad4aecf9b4d..f9f4b41c9c7 100644 --- a/exchanges/binance/binance_websocket.go +++ b/exchanges/binance/binance_websocket.go @@ -564,7 +564,7 @@ func (b *Binance) Unsubscribe(channels subscription.List) error { // manageSubs subscribes or unsubscribes from a list of subscriptions func (b *Binance) manageSubs(op string, subs subscription.List) error { if op == wsSubscribeMethod { - if err := b.Websocket.AddSubscriptions(subs...); err != nil { // Note: AddSubscription will set state to subscribing + if err := b.Websocket.AddSubscriptions(b.Websocket.Conn, subs...); err != nil { // Note: AddSubscription will set state to subscribing return err } } else { @@ -593,7 +593,7 @@ func (b *Binance) manageSubs(op string, subs subscription.List) error { b.Websocket.DataHandler <- err if op == wsSubscribeMethod { - if err2 := b.Websocket.RemoveSubscriptions(nil, subs...); err2 != nil { + if err2 := b.Websocket.RemoveSubscriptions(b.Websocket.Conn, subs...); err2 != nil { err = common.AppendError(err, err2) } } @@ -601,7 +601,7 @@ func (b *Binance) manageSubs(op string, subs subscription.List) error { if op == wsSubscribeMethod { err = common.AppendError(err, subs.SetStates(subscription.SubscribedState)) } else { - err = b.Websocket.RemoveSubscriptions(nil, subs...) + err = b.Websocket.RemoveSubscriptions(b.Websocket.Conn, subs...) } } diff --git a/exchanges/binanceus/binanceus_websocket.go b/exchanges/binanceus/binanceus_websocket.go index 6335bddcab6..52c4a9deeec 100644 --- a/exchanges/binanceus/binanceus_websocket.go +++ b/exchanges/binanceus/binanceus_websocket.go @@ -614,7 +614,7 @@ func (bi *Binanceus) Unsubscribe(channelsToUnsubscribe subscription.List) error return err } } - return bi.Websocket.RemoveSubscriptions(nil, channelsToUnsubscribe...) + return bi.Websocket.RemoveSubscriptions(bi.Websocket.Conn, channelsToUnsubscribe...) } func (bi *Binanceus) setupOrderbookManager() { diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index 30a2c85b35f..ed8891cf1db 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -1325,7 +1325,7 @@ func TestWsSubscribedResponse(t *testing.T) { assert.ErrorContains(t, err, "waiter1", "Should error containing subID if") } - err = b.Websocket.AddSubscriptions(&subscription.Subscription{Key: "waiter1"}) + err = b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Key: "waiter1"}) require.NoError(t, err, "AddSubscriptions must not error") err = b.wsHandleData([]byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`)) assert.NoError(t, err, "wsHandleData should not error") @@ -1339,7 +1339,7 @@ func TestWsSubscribedResponse(t *testing.T) { } func TestWsOrderBook(t *testing.T) { - err := b.Websocket.AddSubscriptions(&subscription.Subscription{Key: 23405, Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsBook}) + err := b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Key: 23405, Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsBook}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON := `[23405,[[38334303613,9348.8,0.53],[38334308111,9348.8,5.98979404],[38331335157,9344.1,1.28965787],[38334302803,9343.8,0.08230094],[38334279092,9343,0.8],[38334307036,9342.938663676,0.8],[38332749107,9342.9,0.2],[38332277330,9342.8,0.85],[38329406786,9342,0.1432012],[38332841570,9341.947288638,0.3],[38332163238,9341.7,0.3],[38334303384,9341.6,0.324],[38332464840,9341.4,0.5],[38331935870,9341.2,0.5],[38334312082,9340.9,0.02126899],[38334261292,9340.8,0.26763],[38334138680,9340.625455254,0.12],[38333896802,9339.8,0.85],[38331627527,9338.9,1.57863959],[38334186713,9338.9,0.26769],[38334305819,9338.8,2.999],[38334211180,9338.75285796,3.999],[38334310699,9337.8,0.10679883],[38334307414,9337.5,1],[38334179822,9337.1,0.26773],[38334306600,9336.659955102,1.79],[38334299667,9336.6,1.1],[38334306452,9336.6,0.13979771],[38325672859,9336.3,1.25],[38334311646,9336.2,1],[38334258509,9336.1,0.37],[38334310592,9336,1.79],[38334310378,9335.6,1.43],[38334132444,9335.2,0.26777],[38331367325,9335,0.07],[38334310703,9335,0.10680562],[38334298209,9334.7,0.08757301],[38334304857,9334.456899462,0.291],[38334309940,9334.088390727,0.0725],[38334310377,9333.7,1.2868],[38334297615,9333.607784,0.1108],[38334095188,9333.3,0.26785],[38334228913,9332.7,0.40861186],[38334300526,9332.363996604,0.3884],[38334310701,9332.2,0.10680562],[38334303548,9332.005382871,0.07],[38334311798,9331.8,0.41285228],[38334301012,9331.7,1.7952],[38334089877,9331.4,0.2679],[38321942150,9331.2,0.2],[38334310670,9330,1.069],[38334063096,9329.6,0.26796],[38334310700,9329.4,0.10680562],[38334310404,9329.3,1],[38334281630,9329.1,6.57150597],[38334036864,9327.7,0.26801],[38334310702,9326.6,0.10680562],[38334311799,9326.1,0.50220625],[38334164163,9326,0.219638],[38334309722,9326,1.5],[38333051682,9325.8,0.26807],[38334302027,9325.7,0.75],[38334203435,9325.366592,0.32397696],[38321967613,9325,0.05],[38334298787,9324.9,0.3],[38334301719,9324.8,3.6227592],[38331316716,9324.763454646,0.71442],[38334310698,9323.8,0.10680562],[38334035499,9323.7,0.23431017],[38334223472,9322.670551788,0.42150603],[38334163459,9322.560399006,0.143967],[38321825171,9320.8,2],[38334075805,9320.467496148,0.30772633],[38334075800,9319.916732238,0.61457592],[38333682302,9319.7,0.0011],[38331323088,9319.116771762,0.12913],[38333677480,9319,0.0199],[38334277797,9318.6,0.89],[38325235155,9318.041088,1.20249],[38334310910,9317.82382938,1.79],[38334311811,9317.2,0.61079138],[38334311812,9317.2,0.71937652],[38333298214,9317.1,50],[38334306359,9317,1.79],[38325531545,9316.382823951,0.21263],[38333727253,9316.3,0.02316372],[38333298213,9316.1,45],[38333836479,9316,2.135],[38324520465,9315.9,2.7681],[38334307411,9315.5,1],[38330313617,9315.3,0.84455],[38334077770,9315.294024,0.01248397],[38334286663,9315.294024,1],[38325533762,9315.290315394,2.40498],[38334310018,9315.2,3],[38333682617,9314.6,0.0011],[38334304794,9314.6,0.76364676],[38334304798,9314.3,0.69242113],[38332915733,9313.8,0.0199],[38334084411,9312.8,1],[38334311893,9350.1,-1.015],[38334302734,9350.3,-0.26737],[38334300732,9350.8,-5.2],[38333957619,9351,-0.90677089],[38334300521,9351,-1.6457],[38334301600,9351.012829557,-0.0523],[38334308878,9351.7,-2.5],[38334299570,9351.921544,-0.1015],[38334279367,9352.1,-0.26732],[38334299569,9352.411802928,-0.4036],[38334202773,9353.4,-0.02139404],[38333918472,9353.7,-1.96412776],[38334278782,9354,-0.26731],[38334278606,9355,-1.2785],[38334302105,9355.439221251,-0.79191542],[38313897370,9355.569409242,-0.43363],[38334292995,9355.584296,-0.0979],[38334216989,9355.8,-0.03686414],[38333894025,9355.9,-0.26721],[38334293798,9355.936691952,-0.4311],[38331159479,9356,-0.4204022],[38333918888,9356.1,-1.10885563],[38334298205,9356.4,-0.20124428],[38328427481,9356.5,-0.1],[38333343289,9356.6,-0.41034213],[38334297205,9356.6,-0.08835018],[38334277927,9356.741101161,-0.0737],[38334311645,9356.8,-0.5],[38334309002,9356.9,-5],[38334309736,9357,-0.10680107],[38334306448,9357.4,-0.18645275],[38333693302,9357.7,-0.2672],[38332815159,9357.8,-0.0011],[38331239824,9358.2,-0.02],[38334271608,9358.3,-2.999],[38334311971,9358.4,-0.55],[38333919260,9358.5,-1.9972841],[38334265365,9358.5,-1.7841],[38334277960,9359,-3],[38334274601,9359.020969848,-3],[38326848839,9359.1,-0.84],[38334291080,9359.247048,-0.16199869],[38326848844,9359.4,-1.84],[38333680200,9359.6,-0.26713],[38331326606,9359.8,-0.84454],[38334309738,9359.8,-0.10680107],[38331314707,9359.9,-0.2],[38333919803,9360.9,-1.41177599],[38323651149,9361.33417827,-0.71442],[38333656906,9361.5,-0.26705],[38334035500,9361.5,-0.40861586],[38334091886,9362.4,-6.85940815],[38334269617,9362.5,-4],[38323629409,9362.545858872,-2.40497],[38334309737,9362.7,-0.10680107],[38334312380,9362.7,-3],[38325280830,9362.8,-1.75123],[38326622800,9362.8,-1.05145],[38333175230,9363,-0.0011],[38326848745,9363.2,-0.79],[38334308960,9363.206775564,-0.12],[38333920234,9363.3,-1.25318113],[38326848843,9363.4,-1.29],[38331239823,9363.4,-0.02],[38333209613,9363.4,-0.26719],[38334299964,9364,-0.05583123],[38323470224,9364.161816648,-0.12912],[38334284711,9365,-0.21346019],[38334299594,9365,-2.6757062],[38323211816,9365.073132585,-0.21262],[38334312456,9365.1,-0.11167861],[38333209612,9365.2,-0.26719],[38327770474,9365.3,-0.0073],[38334298788,9365.3,-0.3],[38334075803,9365.409831204,-0.30772637],[38334309740,9365.5,-0.10680107],[38326608767,9365.7,-2.76809],[38333920657,9365.7,-1.25848083],[38329594226,9366.6,-0.02587],[38334311813,9366.7,-4.72290945],[38316386301,9367.39258128,-2.37581],[38334302026,9367.4,-4.5],[38334228915,9367.9,-0.81725458],[38333921381,9368.1,-1.72213641],[38333175678,9368.2,-0.0011],[38334301150,9368.2,-2.654604],[38334297208,9368.3,-0.78036466],[38334309739,9368.3,-0.10680107],[38331227515,9368.7,-0.02],[38331184470,9369,-0.003975],[38334203436,9369.319616,-0.32397695],[38334269964,9369.7,-0.5],[38328386732,9370,-4.11759935],[38332719555,9370,-0.025],[38333921935,9370.5,-1.2224398],[38334258511,9370.5,-0.35],[38326848842,9370.8,-0.34],[38333985038,9370.9,-0.8551502],[38334283018,9370.9,-1],[38326848744,9371,-1.34]],5]` err = b.wsHandleData([]byte(pressXToJSON)) @@ -1357,7 +1357,7 @@ func TestWsOrderBook(t *testing.T) { } func TestWsTradeResponse(t *testing.T) { - err := b.Websocket.AddSubscriptions(&subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsTrades, Key: 18788}) + err := b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsTrades, Key: 18788}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON := `[18788,[[412685577,1580268444802,11.1998,176.3],[412685575,1580268444802,5,176.29952759],[412685574,1580268374717,1.99069999,176.41],[412685573,1580268374717,1.00930001,176.41],[412685572,1580268358760,0.9907,176.47],[412685571,1580268324362,0.5505,176.44],[412685570,1580268297270,-0.39040819,176.39],[412685568,1580268297270,-0.39780162,176.46475676],[412685567,1580268283470,-0.09,176.41],[412685566,1580268256536,-2.31310783,176.48],[412685565,1580268256536,-0.59669217,176.49],[412685564,1580268256536,-0.9902,176.49],[412685562,1580268194474,0.9902,176.55],[412685561,1580268186215,0.1,176.6],[412685560,1580268185964,-2.17096773,176.5],[412685559,1580268185964,-1.82903227,176.51],[412685558,1580268181215,2.098914,176.53],[412685557,1580268169844,16.7302,176.55],[412685556,1580268169844,3.25,176.54],[412685555,1580268155725,0.23576115,176.45],[412685553,1580268155725,3,176.44596249],[412685552,1580268155725,3.25,176.44],[412685551,1580268155725,5,176.44],[412685550,1580268155725,0.65830078,176.41],[412685549,1580268155725,0.45063807,176.41],[412685548,1580268153825,-0.67604704,176.39],[412685547,1580268145713,2.5883,176.41],[412685543,1580268087513,12.92927,176.33],[412685542,1580268087513,0.40083,176.33],[412685533,1580268005756,-0.17096773,176.32]]]` err = b.wsHandleData([]byte(pressXToJSON)) @@ -1367,7 +1367,7 @@ func TestWsTradeResponse(t *testing.T) { } func TestWsTickerResponse(t *testing.T) { - err := b.Websocket.AddSubscriptions(&subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsTicker, Key: 11534}) + err := b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsTicker, Key: 11534}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON := `[11534,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) @@ -1378,7 +1378,7 @@ func TestWsTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - err = b.Websocket.AddSubscriptions(&subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: wsTicker, Key: 123412}) + err = b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: wsTicker, Key: 123412}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON = `[123412,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) @@ -1389,7 +1389,7 @@ func TestWsTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - err = b.Websocket.AddSubscriptions(&subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: wsTicker, Key: 123413}) + err = b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: wsTicker, Key: 123413}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON = `[123413,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) @@ -1400,7 +1400,7 @@ func TestWsTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - err = b.Websocket.AddSubscriptions(&subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: wsTicker, Key: 123414}) + err = b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: wsTicker, Key: 123414}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON = `[123414,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) @@ -1410,7 +1410,7 @@ func TestWsTickerResponse(t *testing.T) { } func TestWsCandleResponse(t *testing.T) { - err := b.Websocket.AddSubscriptions(&subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsCandles, Key: 343351}) + err := b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsCandles, Key: 343351}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON := `[343351,[[1574698260000,7379.785503,7383.8,7388.3,7379.785503,1.68829482]]]` err = b.wsHandleData([]byte(pressXToJSON)) diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index 4fec8b42086..448685d89e7 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -511,7 +511,7 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { c.Key = int(chanID) // subscribeToChan removes the old subID keyed Subscription - if err := b.Websocket.AddSuccessfulSubscriptions(nil, c); err != nil { + if err := b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, c); err != nil { return fmt.Errorf("%w: %w subID: %s", stream.ErrSubscriptionFailure, err, subID) } @@ -1747,13 +1747,13 @@ func (b *Bitfinex) subscribeToChan(chans subscription.List) error { // Add a temporary Key so we can find this Sub when we get the resp without delay or context switch // Otherwise we might drop the first messages after the subscribed resp c.Key = subID // Note subID string type avoids conflicts with later chanID key - if err = b.Websocket.AddSubscriptions(c); err != nil { + if err = b.Websocket.AddSubscriptions(b.Websocket.Conn, c); err != nil { return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, c.Channel, c.Pairs, err) } // Always remove the temporary subscription keyed by subID defer func() { - _ = b.Websocket.RemoveSubscriptions(nil, c) + _ = b.Websocket.RemoveSubscriptions(b.Websocket.Conn, c) }() respRaw, err := b.Websocket.Conn.SendMessageReturnResponse("subscribe:"+subID, req) @@ -1860,7 +1860,7 @@ func (b *Bitfinex) unsubscribeFromChan(chans subscription.List) error { return wErr } - return b.Websocket.RemoveSubscriptions(nil, c) + return b.Websocket.RemoveSubscriptions(b.Websocket.Conn, c) } // getErrResp takes a json response string and looks for an error event type diff --git a/exchanges/bithumb/bithumb_websocket.go b/exchanges/bithumb/bithumb_websocket.go index 01c61cdeaff..7efc82147db 100644 --- a/exchanges/bithumb/bithumb_websocket.go +++ b/exchanges/bithumb/bithumb_websocket.go @@ -205,7 +205,7 @@ func (b *Bithumb) Subscribe(channelsToSubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(req) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(nil, s) + err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index dce46c9d4e0..61ab3f78143 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -601,7 +601,7 @@ func (b *Bitmex) Subscribe(subs subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(req) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(nil, subs...) + err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, subs...) } return err } @@ -620,7 +620,7 @@ func (b *Bitmex) Unsubscribe(subs subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(req) if err == nil { - err = b.Websocket.RemoveSubscriptions(nil, subs...) + err = b.Websocket.RemoveSubscriptions(b.Websocket.Conn, subs...) } return err } diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index 275090be2d3..632fd4f16f2 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -294,7 +294,7 @@ func (b *Bitstamp) Subscribe(channelsToSubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(req) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(nil, s) + err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) @@ -316,7 +316,7 @@ func (b *Bitstamp) Unsubscribe(channelsToUnsubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(req) if err == nil { - err = b.Websocket.RemoveSubscriptions(nil, s) + err = b.Websocket.RemoveSubscriptions(b.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/btcmarkets/btcmarkets_websocket.go b/exchanges/btcmarkets/btcmarkets_websocket.go index 4c817a3bc25..a0eea01335c 100644 --- a/exchanges/btcmarkets/btcmarkets_websocket.go +++ b/exchanges/btcmarkets/btcmarkets_websocket.go @@ -376,7 +376,7 @@ func (b *BTCMarkets) Subscribe(subs subscription.List) error { err := b.Websocket.Conn.SendJSONMessage(r) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(nil, s) + err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) @@ -416,7 +416,7 @@ func (b *BTCMarkets) Unsubscribe(subs subscription.List) error { err := b.Websocket.Conn.SendJSONMessage(req) if err == nil { - err = b.Websocket.RemoveSubscriptions(nil, s) + err = b.Websocket.RemoveSubscriptions(b.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/btse/btse_websocket.go b/exchanges/btse/btse_websocket.go index f5af037b10c..2576c0251ee 100644 --- a/exchanges/btse/btse_websocket.go +++ b/exchanges/btse/btse_websocket.go @@ -394,7 +394,7 @@ func (b *BTSE) Subscribe(channelsToSubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(sub) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(nil, channelsToSubscribe...) + err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, channelsToSubscribe...) } return err } @@ -409,7 +409,7 @@ func (b *BTSE) Unsubscribe(channelsToUnsubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(unSub) if err == nil { - err = b.Websocket.RemoveSubscriptions(nil, channelsToUnsubscribe...) + err = b.Websocket.RemoveSubscriptions(b.Websocket.Conn, channelsToUnsubscribe...) } return err } diff --git a/exchanges/coinbasepro/coinbasepro_websocket.go b/exchanges/coinbasepro/coinbasepro_websocket.go index 2476be67db2..99f66869839 100644 --- a/exchanges/coinbasepro/coinbasepro_websocket.go +++ b/exchanges/coinbasepro/coinbasepro_websocket.go @@ -425,7 +425,7 @@ func (c *CoinbasePro) Subscribe(subs subscription.List) error { } err := c.Websocket.Conn.SendJSONMessage(r) if err == nil { - err = c.Websocket.AddSuccessfulSubscriptions(nil, subs...) + err = c.Websocket.AddSuccessfulSubscriptions(c.Websocket.Conn, subs...) } return err } @@ -461,7 +461,7 @@ func (c *CoinbasePro) Unsubscribe(subs subscription.List) error { } err := c.Websocket.Conn.SendJSONMessage(r) if err == nil { - err = c.Websocket.RemoveSubscriptions(nil, subs...) + err = c.Websocket.RemoveSubscriptions(c.Websocket.Conn, subs...) } return err } diff --git a/exchanges/coinut/coinut_websocket.go b/exchanges/coinut/coinut_websocket.go index 7f1bb3e139f..ad016a1c336 100644 --- a/exchanges/coinut/coinut_websocket.go +++ b/exchanges/coinut/coinut_websocket.go @@ -620,7 +620,7 @@ func (c *COINUT) Subscribe(subs subscription.List) error { } err = c.Websocket.Conn.SendJSONMessage(subscribe) if err == nil { - err = c.Websocket.AddSuccessfulSubscriptions(nil, s) + err = c.Websocket.AddSuccessfulSubscriptions(c.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) @@ -663,7 +663,7 @@ func (c *COINUT) Unsubscribe(channelToUnsubscribe subscription.List) error { case len(val) == 0, val[0] != "OK": err = common.AppendError(errs, fmt.Errorf("%v unsubscribe failed for channel %v", c.Name, s.Channel)) default: - err = c.Websocket.RemoveSubscriptions(nil, s) + err = c.Websocket.RemoveSubscriptions(c.Websocket.Conn, s) } } if err != nil { diff --git a/exchanges/gemini/gemini_websocket.go b/exchanges/gemini/gemini_websocket.go index dfb24b2cadf..a1255c1f40d 100644 --- a/exchanges/gemini/gemini_websocket.go +++ b/exchanges/gemini/gemini_websocket.go @@ -117,10 +117,10 @@ func (g *Gemini) manageSubs(subs subscription.List, op wsSubOp) error { } if op == wsUnsubscribeOp { - return g.Websocket.RemoveSubscriptions(nil, subs...) + return g.Websocket.RemoveSubscriptions(g.Websocket.Conn, subs...) } - return g.Websocket.AddSuccessfulSubscriptions(nil, subs...) + return g.Websocket.AddSuccessfulSubscriptions(g.Websocket.Conn, subs...) } // WsAuth will connect to Gemini's secure endpoint diff --git a/exchanges/hitbtc/hitbtc_websocket.go b/exchanges/hitbtc/hitbtc_websocket.go index 38d0f6455bb..80a64edea68 100644 --- a/exchanges/hitbtc/hitbtc_websocket.go +++ b/exchanges/hitbtc/hitbtc_websocket.go @@ -526,7 +526,7 @@ func (h *HitBTC) Subscribe(channelsToSubscribe subscription.List) error { err := h.Websocket.Conn.SendJSONMessage(r) if err == nil { - err = h.Websocket.AddSuccessfulSubscriptions(nil, s) + err = h.Websocket.AddSuccessfulSubscriptions(h.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) @@ -562,7 +562,7 @@ func (h *HitBTC) Unsubscribe(subs subscription.List) error { err := h.Websocket.Conn.SendJSONMessage(r) if err == nil { - err = h.Websocket.RemoveSubscriptions(nil, s) + err = h.Websocket.RemoveSubscriptions(h.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index 6ae741854e4..9c11750aaf2 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -570,7 +570,7 @@ func (h *HUOBI) Subscribe(channelsToSubscribe subscription.List) error { }) } if err == nil { - err = h.Websocket.AddSuccessfulSubscriptions(nil, channelsToSubscribe[i]) + err = h.Websocket.AddSuccessfulSubscriptions(h.Websocket.Conn, channelsToSubscribe[i]) } if err != nil { errs = common.AppendError(errs, err) @@ -604,7 +604,7 @@ func (h *HUOBI) Unsubscribe(channelsToUnsubscribe subscription.List) error { }) } if err == nil { - err = h.Websocket.RemoveSubscriptions(nil, channelsToUnsubscribe[i]) + err = h.Websocket.RemoveSubscriptions(h.Websocket.Conn, channelsToUnsubscribe[i]) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index b1e41d31fb8..883ebe9eb3d 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -1229,13 +1229,16 @@ channels: for _, subs := range subscriptions { for i := range *subs { var err error + var conn stream.Connection if common.StringDataContains(authenticatedChannels, (*subs)[i].Subscription.Name) { _, err = k.Websocket.AuthConn.SendMessageReturnResponse((*subs)[i].RequestID, (*subs)[i]) + conn = k.Websocket.AuthConn } else { _, err = k.Websocket.Conn.SendMessageReturnResponse((*subs)[i].RequestID, (*subs)[i]) + conn = k.Websocket.Conn } if err == nil { - err = k.Websocket.AddSuccessfulSubscriptions(nil, (*subs)[i].Channels...) + err = k.Websocket.AddSuccessfulSubscriptions(conn, (*subs)[i].Channels...) } if err != nil { errs = common.AppendError(errs, err) @@ -1288,13 +1291,16 @@ channels: var errs error for i := range unsubs { var err error + var conn stream.Connection if common.StringDataContains(authenticatedChannels, unsubs[i].Subscription.Name) { _, err = k.Websocket.AuthConn.SendMessageReturnResponse(unsubs[i].RequestID, unsubs[i]) + conn = k.Websocket.AuthConn } else { _, err = k.Websocket.Conn.SendMessageReturnResponse(unsubs[i].RequestID, unsubs[i]) + conn = k.Websocket.Conn } if err == nil { - err = k.Websocket.RemoveSubscriptions(nil, unsubs[i].Channels...) + err = k.Websocket.RemoveSubscriptions(conn, unsubs[i].Channels...) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index 8feb1787b6f..21fcbe288e3 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -1007,9 +1007,9 @@ func (ku *Kucoin) manageSubscriptions(subs subscription.List, operation string) errs = common.AppendError(errs, fmt.Errorf("%w: %s from %s", errInvalidMsgType, rType, respRaw)) default: if operation == "unsubscribe" { - err = ku.Websocket.RemoveSubscriptions(nil, s) + err = ku.Websocket.RemoveSubscriptions(ku.Websocket.Conn, s) } else { - err = ku.Websocket.AddSuccessfulSubscriptions(nil, s) + err = ku.Websocket.AddSuccessfulSubscriptions(ku.Websocket.Conn, s) if ku.Verbose { log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s", ku.Name, s.Channel) } diff --git a/exchanges/okcoin/okcoin_websocket.go b/exchanges/okcoin/okcoin_websocket.go index 30a39aa9c06..d8d27dd36b1 100644 --- a/exchanges/okcoin/okcoin_websocket.go +++ b/exchanges/okcoin/okcoin_websocket.go @@ -931,15 +931,15 @@ func (o *Okcoin) manageSubscriptions(operation string, subs subscription.List) e if operation == "unsubscribe" { if authenticatedChannelSubscription { - err = o.Websocket.RemoveSubscriptions(nil, authChannels...) + err = o.Websocket.RemoveSubscriptions(o.Websocket.AuthConn, authChannels...) } else { - err = o.Websocket.RemoveSubscriptions(nil, channels...) + err = o.Websocket.RemoveSubscriptions(o.Websocket.Conn, channels...) } } else { if authenticatedChannelSubscription { - err = o.Websocket.AddSuccessfulSubscriptions(nil, authChannels...) + err = o.Websocket.AddSuccessfulSubscriptions(o.Websocket.AuthConn, authChannels...) } else { - err = o.Websocket.AddSuccessfulSubscriptions(nil, channels...) + err = o.Websocket.AddSuccessfulSubscriptions(o.Websocket.Conn, channels...) } } if err != nil { @@ -974,9 +974,9 @@ func (o *Okcoin) manageSubscriptions(operation string, subs subscription.List) e } } if operation == "unsubscribe" { - return o.Websocket.RemoveSubscriptions(nil, channels...) + return o.Websocket.RemoveSubscriptions(o.Websocket.Conn, channels...) } - return o.Websocket.AddSuccessfulSubscriptions(nil, channels...) + return o.Websocket.AddSuccessfulSubscriptions(o.Websocket.Conn, channels...) } // GetCandlesData represents a candlestick instances list. diff --git a/exchanges/okx/okx_websocket.go b/exchanges/okx/okx_websocket.go index 3ff3660c918..cba7e19f47d 100644 --- a/exchanges/okx/okx_websocket.go +++ b/exchanges/okx/okx_websocket.go @@ -486,9 +486,9 @@ func (ok *Okx) handleSubscription(operation string, subscriptions subscription.L return err } if operation == operationUnsubscribe { - err = ok.Websocket.RemoveSubscriptions(nil, channels...) + err = ok.Websocket.RemoveSubscriptions(ok.Websocket.AuthConn, channels...) } else { - err = ok.Websocket.AddSuccessfulSubscriptions(nil, channels...) + err = ok.Websocket.AddSuccessfulSubscriptions(ok.Websocket.AuthConn, channels...) } if err != nil { return err @@ -510,9 +510,9 @@ func (ok *Okx) handleSubscription(operation string, subscriptions subscription.L return err } if operation == operationUnsubscribe { - err = ok.Websocket.RemoveSubscriptions(nil, channels...) + err = ok.Websocket.RemoveSubscriptions(ok.Websocket.Conn, channels...) } else { - err = ok.Websocket.AddSuccessfulSubscriptions(nil, channels...) + err = ok.Websocket.AddSuccessfulSubscriptions(ok.Websocket.Conn, channels...) } if err != nil { return err @@ -538,10 +538,10 @@ func (ok *Okx) handleSubscription(operation string, subscriptions subscription.L channels = append(channels, authChannels...) if operation == operationUnsubscribe { - return ok.Websocket.RemoveSubscriptions(nil, channels...) + return ok.Websocket.RemoveSubscriptions(ok.Websocket.Conn, channels...) } - return ok.Websocket.AddSuccessfulSubscriptions(nil, channels...) + return ok.Websocket.AddSuccessfulSubscriptions(ok.Websocket.Conn, channels...) } // WsHandleData will read websocket raw data and pass to appropriate handler diff --git a/exchanges/poloniex/poloniex_websocket.go b/exchanges/poloniex/poloniex_websocket.go index ebce7241318..8e43780d81f 100644 --- a/exchanges/poloniex/poloniex_websocket.go +++ b/exchanges/poloniex/poloniex_websocket.go @@ -608,9 +608,9 @@ func (p *Poloniex) manageSubs(subs subscription.List, op wsOp) error { } if err == nil { if op == wsSubscribeOp { - err = p.Websocket.AddSuccessfulSubscriptions(nil, s) + err = p.Websocket.AddSuccessfulSubscriptions(p.Websocket.Conn, s) } else { - err = p.Websocket.RemoveSubscriptions(nil, s) + err = p.Websocket.RemoveSubscriptions(p.Websocket.Conn, s) } } if err != nil { diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 4e087f0bc9c..a68b56ed79f 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -1049,12 +1049,19 @@ func (w *Websocket) SubscribeToChannels(conn Connection, subs subscription.List) // AddSubscriptions adds subscriptions to the subscription store // Sets state to Subscribing unless the state is already set -func (w *Websocket) AddSubscriptions(subs ...*subscription.Subscription) error { +func (w *Websocket) AddSubscriptions(conn Connection, subs ...*subscription.Subscription) error { if w == nil { return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer) } - if w.subscriptions == nil { - w.subscriptions = subscription.NewStore() + var subscriptionStore **subscription.Store + if candidate, ok := w.connections[conn]; ok { + subscriptionStore = &candidate.Subscriptions + } else { + subscriptionStore = &w.subscriptions + } + + if *subscriptionStore == nil { + *subscriptionStore = subscription.NewStore() } var errs error for _, s := range subs { @@ -1063,7 +1070,7 @@ func (w *Websocket) AddSubscriptions(subs ...*subscription.Subscription) error { errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) } } - if err := w.subscriptions.Add(s); err != nil { + if err := (*subscriptionStore).Add(s); err != nil { errs = common.AppendError(errs, err) } } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 92edd545350..c6b661d5303 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -674,9 +674,9 @@ func TestSubscriptions(t *testing.T) { w := new(Websocket) // Do not use NewWebsocket; We want to exercise w.subs == nil assert.ErrorIs(t, (*Websocket)(nil).AddSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket") s := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel} - require.NoError(t, w.AddSubscriptions(s), "Adding first subscription should not error") + require.NoError(t, w.AddSubscriptions(nil, s), "Adding first subscription should not error") assert.Same(t, s, w.GetSubscription(42), "Get Subscription should retrieve the same subscription") - assert.ErrorIs(t, w.AddSubscriptions(s), subscription.ErrDuplicate, "Adding same subscription should return error") + assert.ErrorIs(t, w.AddSubscriptions(nil, s), subscription.ErrDuplicate, "Adding same subscription should return error") assert.Equal(t, subscription.SubscribingState, s.State(), "Should set state to Subscribing") err := w.RemoveSubscriptions(nil, s) @@ -685,7 +685,7 @@ func TestSubscriptions(t *testing.T) { assert.Equal(t, subscription.UnsubscribedState, s.State(), "Should set state to Unsubscribed") require.NoError(t, s.SetState(subscription.ResubscribingState), "SetState must not error") - require.NoError(t, w.AddSubscriptions(s), "Adding first subscription should not error") + require.NoError(t, w.AddSubscriptions(nil, s), "Adding first subscription should not error") assert.Equal(t, subscription.ResubscribingState, s.State(), "Should not change resubscribing state") } @@ -732,7 +732,7 @@ func TestGetSubscription(t *testing.T) { w := NewWebsocket() assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil") s := &subscription.Subscription{Key: 42, Channel: "hello3"} - require.NoError(t, w.AddSubscriptions(s), "AddSubscriptions must not error") + require.NoError(t, w.AddSubscriptions(nil, s), "AddSubscriptions must not error") assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store") } @@ -746,7 +746,7 @@ func TestGetSubscriptions(t *testing.T) { {Key: 42, Channel: "hello3"}, {Key: 45, Channel: "hello4"}, } - err := w.AddSubscriptions(s...) + err := w.AddSubscriptions(nil, s...) require.NoError(t, err, "AddSubscriptions must not error") assert.ElementsMatch(t, s, w.GetSubscriptions(), "GetSubscriptions should return the correct channel details") } @@ -1090,7 +1090,7 @@ func TestGetChannelDifference(t *testing.T) { subs, unsubs := w.GetChannelDifference(nil, subscription.List{{Channel: subscription.CandlesChannel}}) require.Equal(t, 1, len(subs), "Should get the correct number of subs") require.Empty(t, unsubs, "Should get no unsubs") - require.NoError(t, w.AddSubscriptions(subs...), "AddSubscriptions must not error") + require.NoError(t, w.AddSubscriptions(nil, subs...), "AddSubscriptions must not error") subs, unsubs = w.GetChannelDifference(nil, subscription.List{{Channel: subscription.TickerChannel}}) require.Equal(t, 1, len(subs), "Should get the correct number of subs") assert.Equal(t, 1, len(unsubs), "Should get the correct number of unsubs") @@ -1219,7 +1219,7 @@ func TestFlushChannels(t *testing.T) { w.GenerateSubs = newgen.generateSubs subs, err := w.GenerateSubs() require.NoError(t, err, "GenerateSubs must not error") - require.NoError(t, w.AddSubscriptions(subs...), "AddSubscriptions must not error") + require.NoError(t, w.AddSubscriptions(nil, subs...), "AddSubscriptions must not error") err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") w.features.FullPayloadSubscribe = false From 16b0e2275eee95309d67e31ac6245269bb8b0458 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 16 Jul 2024 14:48:13 +1000 Subject: [PATCH 013/138] clean up nils --- docs/ADD_NEW_EXCHANGE.md | 2 +- exchanges/binanceus/binanceus_websocket.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/ADD_NEW_EXCHANGE.md b/docs/ADD_NEW_EXCHANGE.md index 9a4ea88288f..614b7dbb2cb 100644 --- a/docs/ADD_NEW_EXCHANGE.md +++ b/docs/ADD_NEW_EXCHANGE.md @@ -837,7 +837,7 @@ channels: continue } // When we have a successful subscription, we can alert our internal management system of the success. - f.Websocket.AddSuccessfulSubscriptions(nil, channelsToSubscribe[i]) + f.Websocket.AddSuccessfulSubscriptions(f.Websocket.Conn, channelsToSubscribe[i]) } return errs } diff --git a/exchanges/binanceus/binanceus_websocket.go b/exchanges/binanceus/binanceus_websocket.go index 52c4a9deeec..8846dec9f7a 100644 --- a/exchanges/binanceus/binanceus_websocket.go +++ b/exchanges/binanceus/binanceus_websocket.go @@ -590,7 +590,7 @@ func (bi *Binanceus) Subscribe(channelsToSubscribe subscription.List) error { return err } } - return bi.Websocket.AddSuccessfulSubscriptions(nil, channelsToSubscribe...) + return bi.Websocket.AddSuccessfulSubscriptions(bi.Websocket.Conn, channelsToSubscribe...) } // Unsubscribe unsubscribes from a set of channels From 32252b2802a3031d5d2a55e4b162b733a2d625b8 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 16 Jul 2024 15:10:58 +1000 Subject: [PATCH 014/138] cya nils --- exchanges/bitfinex/bitfinex_websocket.go | 2 +- exchanges/kraken/kraken_websocket.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index 448685d89e7..7bb1b044c13 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -1660,7 +1660,7 @@ func (b *Bitfinex) resubOrderbook(c *subscription.Subscription) error { // Resub will block so we have to do this in a goro go func() { - if err := b.Websocket.ResubscribeToChannel(nil, c); err != nil { + if err := b.Websocket.ResubscribeToChannel(b.Websocket.Conn, c); err != nil { log.Errorf(log.ExchangeSys, "%s error resubscribing orderbook: %v", b.Name, err) } }() diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index 883ebe9eb3d..27385d3f8b0 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -800,7 +800,7 @@ func (k *Kraken) wsProcessOrderBook(channelData *WebsocketChannelData, data map[ go func(resub *subscription.Subscription) { // This was locking the main websocket reader routine and a // backlog occurred. So put this into it's own go routine. - errResub := k.Websocket.ResubscribeToChannel(nil, resub) + errResub := k.Websocket.ResubscribeToChannel(k.Websocket.Conn, resub) if errResub != nil { log.Errorf(log.WebsocketMgr, "resubscription failure for %v: %v", From 7c5d9c3821891f2e64847a60950c86055c54d25e Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 17 Jul 2024 13:13:17 +1000 Subject: [PATCH 015/138] don't need to set URL or check if its running --- exchanges/gateio/gateio_websocket.go | 3 --- exchanges/gateio/gateio_ws_delivery_futures.go | 6 +----- exchanges/gateio/gateio_ws_futures.go | 10 +--------- exchanges/gateio/gateio_ws_option.go | 10 +--------- 4 files changed, 3 insertions(+), 26 deletions(-) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index f03b98f7682..9272eda6336 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -62,9 +62,6 @@ var fetchedCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsConnectSpot initiates a websocket connection func (g *Gateio) WsConnectSpot(ctx context.Context, conn stream.Connection) error { - if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return stream.ErrWebsocketNotEnabled - } err := g.CurrencyPairs.IsAssetEnabled(asset.Spot) if err != nil { return err diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index 0f70c82d32f..dca495dccc8 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -40,15 +40,11 @@ var fetchedFuturesCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsDeliveryFuturesConnect initiates a websocket connection for delivery futures account func (g *Gateio) WsDeliveryFuturesConnect(ctx context.Context, conn stream.Connection) error { - if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return stream.ErrWebsocketNotEnabled - } err := g.CurrencyPairs.IsAssetEnabled(asset.DeliveryFutures) if err != nil { return err } - var dialer websocket.Dialer - err = conn.DialContext(ctx, &dialer, http.Header{}) + err = conn.DialContext(ctx, &websocket.Dialer{}, http.Header{}) if err != nil { return err } diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index f6764b3bcaf..f0ccbcb9a13 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -59,19 +59,11 @@ var defaultFuturesSubscriptions = []string{ // WsFuturesConnect initiates a websocket connection for futures account func (g *Gateio) WsFuturesConnect(ctx context.Context, conn stream.Connection) error { - if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return stream.ErrWebsocketNotEnabled - } err := g.CurrencyPairs.IsAssetEnabled(asset.Futures) if err != nil { return err } - var dialer websocket.Dialer - err = g.Websocket.SetWebsocketURL(futuresWebsocketUsdtURL, false, true) - if err != nil { - return err - } - err = conn.DialContext(ctx, &dialer, http.Header{}) + err = conn.DialContext(ctx, &websocket.Dialer{}, http.Header{}) if err != nil { return err } diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index 5c5c7d514fb..16a6709584b 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -69,19 +69,11 @@ var fetchedOptionsCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsOptionsConnect initiates a websocket connection to options websocket endpoints. func (g *Gateio) WsOptionsConnect(ctx context.Context, conn stream.Connection) error { - if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return stream.ErrWebsocketNotEnabled - } err := g.CurrencyPairs.IsAssetEnabled(asset.Options) if err != nil { return err } - var dialer websocket.Dialer - err = g.Websocket.SetWebsocketURL(optionsWebsocketURL, false, true) - if err != nil { - return err - } - err = conn.DialContext(ctx, &dialer, http.Header{}) + err = conn.DialContext(ctx, &websocket.Dialer{}, http.Header{}) if err != nil { return err } From 2f93b64fa2dec021124e654fb6e79419b71e56f6 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 17 Jul 2024 15:38:29 +1000 Subject: [PATCH 016/138] stream match update --- exchanges/stream/stream_match.go | 67 +++++++++++---------------- exchanges/stream/stream_match_test.go | 10 +--- 2 files changed, 28 insertions(+), 49 deletions(-) diff --git a/exchanges/stream/stream_match.go b/exchanges/stream/stream_match.go index dc2d46b3ac2..412c08bb72b 100644 --- a/exchanges/stream/stream_match.go +++ b/exchanges/stream/stream_match.go @@ -5,11 +5,11 @@ import ( "sync" ) +var errSignatureCollision = errors.New("signature collision") + // NewMatch returns a new Match func NewMatch() *Match { - return &Match{ - m: make(map[interface{}]chan []byte), - } + return &Match{m: make(map[any]chan<- []byte)} } // Match is a distributed subtype that handles the matching of requests and @@ -17,64 +17,49 @@ func NewMatch() *Match { // connections. Stream systems fan in all incoming payloads to one routine for // processing. type Match struct { - m map[interface{}]chan []byte + m map[any]chan<- []byte mu sync.Mutex } -// Matcher defines a payload matching return mechanism -type Matcher struct { - C chan []byte - sig interface{} - m *Match -} - // Incoming matches with request, disregarding the returned payload -func (m *Match) Incoming(signature interface{}) bool { +func (m *Match) Incoming(signature any) bool { return m.IncomingWithData(signature, nil) } // IncomingWithData matches with requests and takes in the returned payload, to // be processed outside of a stream processing routine and returns true if a handler was found -func (m *Match) IncomingWithData(signature interface{}, data []byte) bool { +func (m *Match) IncomingWithData(signature any, data []byte) bool { m.mu.Lock() defer m.mu.Unlock() ch, ok := m.m[signature] - if ok { - select { - case ch <- data: - default: - // this shouldn't occur but if it does continue to process as normal - return false - } - return true + if !ok { + return false } - return false + ch <- data + close(ch) + delete(m.m, signature) + return true + } // Set the signature response channel for incoming data -func (m *Match) Set(signature interface{}) (Matcher, error) { - var ch chan []byte +func (m *Match) Set(signature any) (<-chan []byte, error) { m.mu.Lock() + defer m.mu.Unlock() if _, ok := m.m[signature]; ok { - m.mu.Unlock() - return Matcher{}, errors.New("signature collision") + return nil, errSignatureCollision } - // This is buffered so we don't need to wait for receiver. - ch = make(chan []byte, 1) + ch := make(chan []byte, 1) // This is buffered so we don't need to wait for receiver. m.m[signature] = ch - m.mu.Unlock() - - return Matcher{ - C: ch, - sig: signature, - m: m, - }, nil + return ch, nil } -// Cleanup closes underlying channel and deletes signature from map -func (m *Matcher) Cleanup() { - m.m.mu.Lock() - close(m.C) - delete(m.m.m, m.sig) - m.m.mu.Unlock() +// Timeout the signature response channel +func (m *Match) Timeout(signature any) { + m.mu.Lock() + defer m.mu.Unlock() + if ch, ok := m.m[signature]; ok { + close(ch) + delete(m.m, signature) + } } diff --git a/exchanges/stream/stream_match_test.go b/exchanges/stream/stream_match_test.go index 2659122535f..aa7e4e4ec62 100644 --- a/exchanges/stream/stream_match_test.go +++ b/exchanges/stream/stream_match_test.go @@ -18,7 +18,7 @@ func TestMatch(t *testing.T) { t.Fatal("should not be able to match") } - m, err := nm.Set("hello") + ch, err := nm.Set("hello") if err != nil { t.Fatal(err) } @@ -28,10 +28,6 @@ func TestMatch(t *testing.T) { t.Fatal("error cannot be nil as this collision cannot occur") } - if m.sig != "hello" { - t.Fatal(err) - } - // try and match with initial payload if !nm.Incoming("hello") { t.Fatal("should of matched") @@ -42,9 +38,7 @@ func TestMatch(t *testing.T) { fmt.Println("should not have been able to match") } - if data := <-m.C; data != nil { + if data := <-ch; data != nil { t.Fatal("data chan should be nil") } - - m.Cleanup() } From dd94e4eb927de54c2e3bf5845908e98f458c3d14 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 17 Jul 2024 15:54:31 +1000 Subject: [PATCH 017/138] update tests --- exchanges/bitfinex/bitfinex_test.go | 7 ++-- exchanges/stream/stream_match_test.go | 41 ++++++++++-------------- exchanges/stream/websocket_connection.go | 20 ++++++------ 3 files changed, 29 insertions(+), 39 deletions(-) diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index 30a2c85b35f..86a1cfa7293 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -1316,7 +1316,7 @@ func TestWsCancelOffer(t *testing.T) { } func TestWsSubscribedResponse(t *testing.T) { - m, err := b.Websocket.Match.Set("subscribe:waiter1") + ch, err := b.Websocket.Match.Set("subscribe:waiter1") assert.NoError(t, err, "Setting a matcher should not error") err = b.wsHandleData([]byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`)) if assert.Error(t, err, "Should error if sub is not registered yet") { @@ -1329,13 +1329,12 @@ func TestWsSubscribedResponse(t *testing.T) { require.NoError(t, err, "AddSubscriptions must not error") err = b.wsHandleData([]byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`)) assert.NoError(t, err, "wsHandleData should not error") - if assert.NotEmpty(t, m.C, "Matcher should have received a sub notification") { - msg := <-m.C + if assert.NotEmpty(t, ch, "Matcher should have received a sub notification") { + msg := <-ch cID, err := jsonparser.GetInt(msg, "chanId") assert.NoError(t, err, "Should get chanId from sub notification without error") assert.EqualValues(t, 224555, cID, "Should get the correct chanId through the matcher notification") } - m.Cleanup() } func TestWsOrderBook(t *testing.T) { diff --git a/exchanges/stream/stream_match_test.go b/exchanges/stream/stream_match_test.go index aa7e4e4ec62..da8b145e1f9 100644 --- a/exchanges/stream/stream_match_test.go +++ b/exchanges/stream/stream_match_test.go @@ -1,44 +1,37 @@ package stream import ( - "fmt" "testing" + + "github.com/stretchr/testify/require" ) func TestMatch(t *testing.T) { t.Parallel() - bm := &Match{} - if bm.Incoming("wow") { - t.Fatal("Should not have matched") - } - nm := NewMatch() + require.False(t, nm.Incoming("wow")) + // try to match with unset signature - if nm.Incoming("hello") { - t.Fatal("should not be able to match") - } + require.False(t, nm.Incoming("hello")) ch, err := nm.Set("hello") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) _, err = nm.Set("hello") - if err == nil { - t.Fatal("error cannot be nil as this collision cannot occur") - } + require.ErrorIs(t, err, errSignatureCollision) // try and match with initial payload - if !nm.Incoming("hello") { - t.Fatal("should of matched") - } + require.True(t, nm.Incoming("hello")) + require.Nil(t, <-ch) // put in secondary payload with conflicting signature - if nm.Incoming("hello") { - fmt.Println("should not have been able to match") - } + require.False(t, nm.Incoming("hello")) + + ch, err = nm.Set("hello") + require.NoError(t, err) + + expected := []byte("payload") + require.True(t, nm.IncomingWithData("hello", expected)) - if data := <-ch; data != nil { - t.Fatal("data chan should be nil") - } + require.Equal(t, expected, <-ch) } diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 6a00d01ab74..7c46183e3e1 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -22,34 +22,32 @@ import ( // SendMessageReturnResponse will send a WS message to the connection and wait // for response func (w *WebsocketConnection) SendMessageReturnResponse(signature, request interface{}) ([]byte, error) { - m, err := w.Match.Set(signature) + outbound, err := json.Marshal(request) if err != nil { - return nil, err + return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err) } - defer m.Cleanup() - b, err := json.Marshal(request) + ch, err := w.Match.Set(signature) if err != nil { - return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err) + return nil, err } start := time.Now() - err = w.SendRawMessage(websocket.TextMessage, b) + err = w.SendRawMessage(websocket.TextMessage, outbound) if err != nil { return nil, err } timer := time.NewTimer(w.ResponseMaxLimit) - select { - case payload := <-m.C: + case payload := <-ch: + timer.Stop() if w.Reporter != nil { - w.Reporter.Latency(w.ExchangeName, b, time.Since(start)) + w.Reporter.Latency(w.ExchangeName, payload, time.Since(start)) } - return payload, nil case <-timer.C: - timer.Stop() + w.Match.Timeout(signature) return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", w.ExchangeName, signature) } } From ca597f32edb8c621b19cc1c8e9c290cf5f41a707 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 17 Jul 2024 16:02:10 +1000 Subject: [PATCH 018/138] linter: fix --- exchanges/stream/stream_match.go | 1 - 1 file changed, 1 deletion(-) diff --git a/exchanges/stream/stream_match.go b/exchanges/stream/stream_match.go index 412c08bb72b..edf6d25bfa9 100644 --- a/exchanges/stream/stream_match.go +++ b/exchanges/stream/stream_match.go @@ -39,7 +39,6 @@ func (m *Match) IncomingWithData(signature any, data []byte) bool { close(ch) delete(m.m, signature) return true - } // Set the signature response channel for incoming data From 2e0f0ae9c4466270587a4e1a75b42216e6cec739 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 19 Jul 2024 09:21:56 +1000 Subject: [PATCH 019/138] glorious: nits + handle context cancellations --- exchanges/binance/binance_websocket.go | 2 +- exchanges/bitfinex/bitfinex_websocket.go | 12 +++++------ exchanges/bybit/bybit_websocket.go | 6 +++--- exchanges/coinut/coinut_websocket.go | 20 +++++++++---------- exchanges/deribit/deribit_websocket.go | 4 ++-- exchanges/deribit/deribit_websocket_eps.go | 2 +- exchanges/gateio/gateio_websocket.go | 2 +- .../gateio/gateio_ws_delivery_futures.go | 4 ++-- exchanges/gateio/gateio_ws_futures.go | 4 ++-- exchanges/gateio/gateio_ws_option.go | 2 +- exchanges/hitbtc/hitbtc_websocket.go | 16 +++++++-------- exchanges/huobi/huobi_websocket.go | 6 +++--- exchanges/kraken/kraken_websocket.go | 14 ++++++------- exchanges/kucoin/kucoin_websocket.go | 2 +- exchanges/okcoin/okcoin_websocket.go | 2 +- exchanges/okcoin/okcoin_ws_trade.go | 5 +++-- exchanges/stream/stream_match.go | 4 ++-- exchanges/stream/stream_match_test.go | 5 +++++ exchanges/stream/stream_types.go | 3 ++- exchanges/stream/websocket_connection.go | 15 ++++++++++---- exchanges/stream/websocket_test.go | 15 ++++++++++++-- 21 files changed, 85 insertions(+), 60 deletions(-) diff --git a/exchanges/binance/binance_websocket.go b/exchanges/binance/binance_websocket.go index 1168dc664a3..d8f71764777 100644 --- a/exchanges/binance/binance_websocket.go +++ b/exchanges/binance/binance_websocket.go @@ -579,7 +579,7 @@ func (b *Binance) manageSubs(op string, subs subscription.List) error { Params: subs.QualifiedChannels(), } - respRaw, err := b.Websocket.Conn.SendMessageReturnResponse(req.ID, req) + respRaw, err := b.Websocket.Conn.SendMessageReturnResponse(context.TODO(), req.ID, req) if err == nil { if v, d, _, rErr := jsonparser.Get(respRaw, "result"); rErr != nil { err = rErr diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index ed04da9988f..fe7e7349a1f 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -1756,7 +1756,7 @@ func (b *Bitfinex) subscribeToChan(chans subscription.List) error { _ = b.Websocket.RemoveSubscriptions(c) }() - respRaw, err := b.Websocket.Conn.SendMessageReturnResponse("subscribe:"+subID, req) + respRaw, err := b.Websocket.Conn.SendMessageReturnResponse(context.TODO(), "subscribe:"+subID, req) if err != nil { return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pairs) } @@ -1849,7 +1849,7 @@ func (b *Bitfinex) unsubscribeFromChan(chans subscription.List) error { "chanId": chanID, } - respRaw, err := b.Websocket.Conn.SendMessageReturnResponse("unsubscribe:"+strconv.Itoa(chanID), req) + respRaw, err := b.Websocket.Conn.SendMessageReturnResponse(context.TODO(), "unsubscribe:"+strconv.Itoa(chanID), req) if err != nil { return err } @@ -1926,7 +1926,7 @@ func (b *Bitfinex) WsSendAuth(ctx context.Context) error { func (b *Bitfinex) WsNewOrder(data *WsNewOrderRequest) (string, error) { data.CustomID = b.Websocket.AuthConn.GenerateMessageID(false) request := makeRequestInterface(wsOrderNew, data) - resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(data.CustomID, request) + resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), data.CustomID, request) if err != nil { return "", err } @@ -1983,7 +1983,7 @@ func (b *Bitfinex) WsNewOrder(data *WsNewOrderRequest) (string, error) { // WsModifyOrder authenticated modify order request func (b *Bitfinex) WsModifyOrder(data *WsUpdateOrderRequest) error { request := makeRequestInterface(wsOrderUpdate, data) - resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(data.OrderID, request) + resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), data.OrderID, request) if err != nil { return err } @@ -2037,7 +2037,7 @@ func (b *Bitfinex) WsCancelOrder(orderID int64) error { OrderID: orderID, } request := makeRequestInterface(wsOrderCancel, cancel) - resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(orderID, request) + resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), orderID, request) if err != nil { return err } @@ -2094,7 +2094,7 @@ func (b *Bitfinex) WsCancelOffer(orderID int64) error { OrderID: orderID, } request := makeRequestInterface(wsFundingOfferCancel, cancel) - resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(orderID, request) + resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), orderID, request) if err != nil { return err } diff --git a/exchanges/bybit/bybit_websocket.go b/exchanges/bybit/bybit_websocket.go index 16ac416ce94..9e6308f2e43 100644 --- a/exchanges/bybit/bybit_websocket.go +++ b/exchanges/bybit/bybit_websocket.go @@ -118,7 +118,7 @@ func (by *Bybit) WsAuth(ctx context.Context) error { Operation: "auth", Args: []interface{}{creds.Key, intNonce, sign}, } - resp, err := by.Websocket.AuthConn.SendMessageReturnResponse(req.RequestID, req) + resp, err := by.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), req.RequestID, req) if err != nil { return err } @@ -220,12 +220,12 @@ func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe su for a := range payloads { var response []byte if payloads[a].auth { - response, err = by.Websocket.AuthConn.SendMessageReturnResponse(payloads[a].RequestID, payloads[a]) + response, err = by.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), payloads[a].RequestID, payloads[a]) if err != nil { return err } } else { - response, err = by.Websocket.Conn.SendMessageReturnResponse(payloads[a].RequestID, payloads[a]) + response, err = by.Websocket.Conn.SendMessageReturnResponse(context.TODO(), payloads[a].RequestID, payloads[a]) if err != nil { return err } diff --git a/exchanges/coinut/coinut_websocket.go b/exchanges/coinut/coinut_websocket.go index b87792aede2..ec74b7d2f5b 100644 --- a/exchanges/coinut/coinut_websocket.go +++ b/exchanges/coinut/coinut_websocket.go @@ -480,7 +480,7 @@ func (c *COINUT) WsGetInstruments() (Instruments, error) { SecurityType: strings.ToUpper(asset.Spot.String()), Nonce: getNonce(), } - resp, err := c.Websocket.Conn.SendMessageReturnResponse(request.Nonce, request) + resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Nonce, request) if err != nil { return list, err } @@ -648,7 +648,7 @@ func (c *COINUT) Unsubscribe(channelToUnsubscribe subscription.List) error { Subscribe: false, Nonce: getNonce(), } - resp, err := c.Websocket.Conn.SendMessageReturnResponse(subscribe.Nonce, subscribe) + resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), subscribe.Nonce, subscribe) if err != nil { errs = common.AppendError(errs, err) continue @@ -691,7 +691,7 @@ func (c *COINUT) wsAuthenticate(ctx context.Context) error { } r.Hmac = crypto.HexEncodeToString(hmac) - resp, err := c.Websocket.Conn.SendMessageReturnResponse(r.Nonce, r) + resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), r.Nonce, r) if err != nil { return err } @@ -714,7 +714,7 @@ func (c *COINUT) wsGetAccountBalance() (*UserBalance, error) { Request: "user_balance", Nonce: getNonce(), } - resp, err := c.Websocket.Conn.SendMessageReturnResponse(accBalance.Nonce, accBalance) + resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), accBalance.Nonce, accBalance) if err != nil { return nil, err } @@ -750,7 +750,7 @@ func (c *COINUT) wsSubmitOrder(o *WsSubmitOrderParameters) (*order.Detail, error if o.OrderID > 0 { orderSubmissionRequest.OrderID = o.OrderID } - resp, err := c.Websocket.Conn.SendMessageReturnResponse(orderSubmissionRequest.Nonce, orderSubmissionRequest) + resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), orderSubmissionRequest.Nonce, orderSubmissionRequest) if err != nil { return nil, err } @@ -793,7 +793,7 @@ func (c *COINUT) wsSubmitOrders(orders []WsSubmitOrderParameters) ([]order.Detai orderRequest.Nonce = getNonce() orderRequest.Request = "new_orders" - resp, err := c.Websocket.Conn.SendMessageReturnResponse(orderRequest.Nonce, orderRequest) + resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), orderRequest.Nonce, orderRequest) if err != nil { errs = append(errs, err) return nil, errs @@ -829,7 +829,7 @@ func (c *COINUT) wsGetOpenOrders(curr string) (*WsUserOpenOrdersResponse, error) openOrdersRequest.Nonce = getNonce() openOrdersRequest.InstrumentID = c.instrumentMap.LookupID(curr) - resp, err := c.Websocket.Conn.SendMessageReturnResponse(openOrdersRequest.Nonce, openOrdersRequest) + resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), openOrdersRequest.Nonce, openOrdersRequest) if err != nil { return response, err } @@ -862,7 +862,7 @@ func (c *COINUT) wsCancelOrder(cancellation *WsCancelOrderParameters) (*CancelOr cancellationRequest.OrderID = cancellation.OrderID cancellationRequest.Nonce = getNonce() - resp, err := c.Websocket.Conn.SendMessageReturnResponse(cancellationRequest.Nonce, cancellationRequest) + resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), cancellationRequest.Nonce, cancellationRequest) if err != nil { return response, err } @@ -903,7 +903,7 @@ func (c *COINUT) wsCancelOrders(cancellations []WsCancelOrderParameters) (*Cance cancelOrderRequest.Request = "cancel_orders" cancelOrderRequest.Nonce = getNonce() - resp, err := c.Websocket.Conn.SendMessageReturnResponse(cancelOrderRequest.Nonce, cancelOrderRequest) + resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), cancelOrderRequest.Nonce, cancelOrderRequest) if err != nil { return response, err } @@ -933,7 +933,7 @@ func (c *COINUT) wsGetTradeHistory(p currency.Pair, start, limit int64) (*WsTrad request.Start = start request.Limit = limit - resp, err := c.Websocket.Conn.SendMessageReturnResponse(request.Nonce, request) + resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Nonce, request) if err != nil { return response, err } diff --git a/exchanges/deribit/deribit_websocket.go b/exchanges/deribit/deribit_websocket.go index 29b4163d36c..9c096022c2c 100644 --- a/exchanges/deribit/deribit_websocket.go +++ b/exchanges/deribit/deribit_websocket.go @@ -147,7 +147,7 @@ func (d *Deribit) wsLogin(ctx context.Context) error { "signature": crypto.HexEncodeToString(hmac), }, } - resp, err := d.Websocket.Conn.SendMessageReturnResponse(request.ID, request) + resp, err := d.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request) if err != nil { d.Websocket.SetCanUseAuthenticatedEndpoints(false) return err @@ -1165,7 +1165,7 @@ func (d *Deribit) handleSubscription(operation string, channels subscription.Lis return err } for x := range payloads { - data, err := d.Websocket.Conn.SendMessageReturnResponse(payloads[x].ID, payloads[x]) + data, err := d.Websocket.Conn.SendMessageReturnResponse(context.TODO(), payloads[x].ID, payloads[x]) if err != nil { return err } diff --git a/exchanges/deribit/deribit_websocket_eps.go b/exchanges/deribit/deribit_websocket_eps.go index 19f3010c7e2..4c639cc60e4 100644 --- a/exchanges/deribit/deribit_websocket_eps.go +++ b/exchanges/deribit/deribit_websocket_eps.go @@ -2406,7 +2406,7 @@ func (d *Deribit) sendWsPayload(ep request.EndpointLimit, input *WsRequest, resp log.Debugf(log.RequestSys, "%s attempt %d", d.Name, attempt) } var payload []byte - payload, err = d.Websocket.Conn.SendMessageReturnResponse(input.ID, input) + payload, err = d.Websocket.Conn.SendMessageReturnResponse(context.TODO(), input.ID, input) if err != nil { return err } diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index a5b7b43f71c..26c1c23fb8d 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -697,7 +697,7 @@ func (g *Gateio) handleSubscription(event string, channelsToSubscribe subscripti } var errs error for k := range payloads { - result, err := g.Websocket.Conn.SendMessageReturnResponse(payloads[k].ID, payloads[k]) + result, err := g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), payloads[k].ID, payloads[k]) if err != nil { errs = common.AppendError(errs, err) continue diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index b9242981033..1a41e21a883 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -207,9 +207,9 @@ func (g *Gateio) handleDeliveryFuturesSubscription(event string, channelsToSubsc for con, val := range payloads { for k := range val { if con == 0 { - respByte, err = g.Websocket.Conn.SendMessageReturnResponse(val[k].ID, val[k]) + respByte, err = g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), val[k].ID, val[k]) } else { - respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(val[k].ID, val[k]) + respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), val[k].ID, val[k]) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 8962212d0ca..749e8e0a839 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -287,9 +287,9 @@ func (g *Gateio) handleFuturesSubscription(event string, channelsToSubscribe sub for con, val := range payloads { for k := range val { if con == 0 { - respByte, err = g.Websocket.Conn.SendMessageReturnResponse(val[k].ID, val[k]) + respByte, err = g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), val[k].ID, val[k]) } else { - respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(val[k].ID, val[k]) + respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), val[k].ID, val[k]) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index fe1384c8e52..a573a0df1a3 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -319,7 +319,7 @@ func (g *Gateio) handleOptionsSubscription(event string, channelsToSubscribe sub } var errs error for k := range payloads { - result, err := g.Websocket.Conn.SendMessageReturnResponse(payloads[k].ID, payloads[k]) + result, err := g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), payloads[k].ID, payloads[k]) if err != nil { errs = common.AppendError(errs, err) continue diff --git a/exchanges/hitbtc/hitbtc_websocket.go b/exchanges/hitbtc/hitbtc_websocket.go index c361daaa2f5..68f68de93be 100644 --- a/exchanges/hitbtc/hitbtc_websocket.go +++ b/exchanges/hitbtc/hitbtc_websocket.go @@ -632,7 +632,7 @@ func (h *HitBTC) wsPlaceOrder(pair currency.Pair, side string, price, quantity f }, ID: id, } - resp, err := h.Websocket.Conn.SendMessageReturnResponse(id, request) + resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), id, request) if err != nil { return nil, fmt.Errorf("%v %v", h.Name, err) } @@ -659,7 +659,7 @@ func (h *HitBTC) wsCancelOrder(clientOrderID string) (*WsCancelOrderResponse, er }, ID: h.Websocket.Conn.GenerateMessageID(false), } - resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request) + resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request) if err != nil { return nil, fmt.Errorf("%v %v", h.Name, err) } @@ -689,7 +689,7 @@ func (h *HitBTC) wsReplaceOrder(clientOrderID string, quantity, price float64) ( }, ID: h.Websocket.Conn.GenerateMessageID(false), } - resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request) + resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request) if err != nil { return nil, fmt.Errorf("%v %v", h.Name, err) } @@ -714,7 +714,7 @@ func (h *HitBTC) wsGetActiveOrders() (*wsActiveOrdersResponse, error) { Params: WsReplaceOrderRequestData{}, ID: h.Websocket.Conn.GenerateMessageID(false), } - resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request) + resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request) if err != nil { return nil, fmt.Errorf("%v %v", h.Name, err) } @@ -739,7 +739,7 @@ func (h *HitBTC) wsGetTradingBalance() (*WsGetTradingBalanceResponse, error) { Params: WsReplaceOrderRequestData{}, ID: h.Websocket.Conn.GenerateMessageID(false), } - resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request) + resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request) if err != nil { return nil, fmt.Errorf("%v %v", h.Name, err) } @@ -763,7 +763,7 @@ func (h *HitBTC) wsGetCurrencies(currencyItem currency.Code) (*WsGetCurrenciesRe }, ID: h.Websocket.Conn.GenerateMessageID(false), } - resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request) + resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request) if err != nil { return nil, fmt.Errorf("%v %v", h.Name, err) } @@ -792,7 +792,7 @@ func (h *HitBTC) wsGetSymbols(c currency.Pair) (*WsGetSymbolsResponse, error) { }, ID: h.Websocket.Conn.GenerateMessageID(false), } - resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request) + resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request) if err != nil { return nil, fmt.Errorf("%v %v", h.Name, err) } @@ -824,7 +824,7 @@ func (h *HitBTC) wsGetTrades(c currency.Pair, limit int64, sort, by string) (*Ws }, ID: h.Websocket.Conn.GenerateMessageID(false), } - resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request) + resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request) if err != nil { return nil, fmt.Errorf("%v %v", h.Name, err) } diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index 4a4e4d7adf7..3c86ac62e1e 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -700,7 +700,7 @@ func (h *HUOBI) wsGetAccountsList(ctx context.Context) (*WsAuthenticatedAccounts } request.Signature = crypto.Base64Encode(hmac) request.ClientID = h.Websocket.AuthConn.GenerateMessageID(true) - resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(request.ClientID, request) + resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.ClientID, request) if err != nil { return nil, err } @@ -752,7 +752,7 @@ func (h *HUOBI) wsGetOrdersList(ctx context.Context, accountID int64, pair curre request.Signature = crypto.Base64Encode(hmac) request.ClientID = h.Websocket.AuthConn.GenerateMessageID(true) - resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(request.ClientID, request) + resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.ClientID, request) if err != nil { return nil, err } @@ -794,7 +794,7 @@ func (h *HUOBI) wsGetOrderDetails(ctx context.Context, orderID string) (*WsAuthe } request.Signature = crypto.Base64Encode(hmac) request.ClientID = h.Websocket.AuthConn.GenerateMessageID(true) - resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(request.ClientID, request) + resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.ClientID, request) if err != nil { return nil, err } diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index eedccc8ff49..4d110e75c26 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -1230,9 +1230,9 @@ channels: for i := range *subs { var err error if common.StringDataContains(authenticatedChannels, (*subs)[i].Subscription.Name) { - _, err = k.Websocket.AuthConn.SendMessageReturnResponse((*subs)[i].RequestID, (*subs)[i]) + _, err = k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), (*subs)[i].RequestID, (*subs)[i]) } else { - _, err = k.Websocket.Conn.SendMessageReturnResponse((*subs)[i].RequestID, (*subs)[i]) + _, err = k.Websocket.Conn.SendMessageReturnResponse(context.TODO(), (*subs)[i].RequestID, (*subs)[i]) } if err == nil { err = k.Websocket.AddSuccessfulSubscriptions((*subs)[i].Channels...) @@ -1289,9 +1289,9 @@ channels: for i := range unsubs { var err error if common.StringDataContains(authenticatedChannels, unsubs[i].Subscription.Name) { - _, err = k.Websocket.AuthConn.SendMessageReturnResponse(unsubs[i].RequestID, unsubs[i]) + _, err = k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), unsubs[i].RequestID, unsubs[i]) } else { - _, err = k.Websocket.Conn.SendMessageReturnResponse(unsubs[i].RequestID, unsubs[i]) + _, err = k.Websocket.Conn.SendMessageReturnResponse(context.TODO(), unsubs[i].RequestID, unsubs[i]) } if err == nil { err = k.Websocket.RemoveSubscriptions(unsubs[i].Channels...) @@ -1309,7 +1309,7 @@ func (k *Kraken) wsAddOrder(request *WsAddOrderRequest) (string, error) { request.RequestID = id request.Event = krakenWsAddOrder request.Token = authToken - jsonResp, err := k.Websocket.AuthConn.SendMessageReturnResponse(id, request) + jsonResp, err := k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), id, request) if err != nil { return "", err } @@ -1348,7 +1348,7 @@ func (k *Kraken) wsCancelOrder(orderID string) error { RequestID: id, } - resp, err := k.Websocket.AuthConn.SendMessageReturnResponse(id, request) + resp, err := k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), id, request) if err != nil { return fmt.Errorf("%w %s: %w", errCancellingOrder, orderID, err) } @@ -1378,7 +1378,7 @@ func (k *Kraken) wsCancelAllOrders() (*WsCancelOrderResponse, error) { RequestID: id, } - jsonResp, err := k.Websocket.AuthConn.SendMessageReturnResponse(id, request) + jsonResp, err := k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), id, request) if err != nil { return &WsCancelOrderResponse{}, err } diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index 5b7c14bde48..ac9225857ac 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -996,7 +996,7 @@ func (ku *Kucoin) manageSubscriptions(subs subscription.List, operation string) PrivateChannel: s.Authenticated, Response: true, } - if respRaw, err := ku.Websocket.Conn.SendMessageReturnResponse("msgID:"+msgID, req); err != nil { + if respRaw, err := ku.Websocket.Conn.SendMessageReturnResponse(context.TODO(), "msgID:"+msgID, req); err != nil { errs = common.AppendError(errs, err) } else { rType, err := jsonparser.GetUnsafeString(respRaw, "type") diff --git a/exchanges/okcoin/okcoin_websocket.go b/exchanges/okcoin/okcoin_websocket.go index b5af0f29b83..5ab48cb86f6 100644 --- a/exchanges/okcoin/okcoin_websocket.go +++ b/exchanges/okcoin/okcoin_websocket.go @@ -147,7 +147,7 @@ func (o *Okcoin) WsLogin(ctx context.Context, dialer *websocket.Dialer) error { }, }, } - _, err = o.Websocket.AuthConn.SendMessageReturnResponse("login", authRequest) + _, err = o.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), "login", authRequest) if err != nil { return err } diff --git a/exchanges/okcoin/okcoin_ws_trade.go b/exchanges/okcoin/okcoin_ws_trade.go index 85b6102e24d..054396ceb71 100644 --- a/exchanges/okcoin/okcoin_ws_trade.go +++ b/exchanges/okcoin/okcoin_ws_trade.go @@ -1,6 +1,7 @@ package okcoin import ( + "context" "encoding/json" "errors" "fmt" @@ -153,9 +154,9 @@ func (o *Okcoin) SendWebsocketRequest(operation string, data, result interface{} var err error // TODO: ratelimits for websocket if authenticated { - byteData, err = o.Websocket.AuthConn.SendMessageReturnResponse(req.ID, req) + byteData, err = o.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), req.ID, req) } else { - byteData, err = o.Websocket.Conn.SendMessageReturnResponse(req.ID, req) + byteData, err = o.Websocket.Conn.SendMessageReturnResponse(context.TODO(), req.ID, req) } if err != nil { return err diff --git a/exchanges/stream/stream_match.go b/exchanges/stream/stream_match.go index edf6d25bfa9..5a024fc3f49 100644 --- a/exchanges/stream/stream_match.go +++ b/exchanges/stream/stream_match.go @@ -53,8 +53,8 @@ func (m *Match) Set(signature any) (<-chan []byte, error) { return ch, nil } -// Timeout the signature response channel -func (m *Match) Timeout(signature any) { +// RemoveSignature removes the signature response from map and closes the channel. +func (m *Match) RemoveSignature(signature any) { m.mu.Lock() defer m.mu.Unlock() if ch, ok := m.m[signature]; ok { diff --git a/exchanges/stream/stream_match_test.go b/exchanges/stream/stream_match_test.go index da8b145e1f9..3530c12e686 100644 --- a/exchanges/stream/stream_match_test.go +++ b/exchanges/stream/stream_match_test.go @@ -34,4 +34,9 @@ func TestMatch(t *testing.T) { require.True(t, nm.IncomingWithData("hello", expected)) require.Equal(t, expected, <-ch) + + _, err = nm.Set("purge me") + require.NoError(t, err) + nm.RemoveSignature("purge me") + require.False(t, nm.IncomingWithData("purge me", expected)) } diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index e342c74dace..578beb2f0d8 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -1,6 +1,7 @@ package stream import ( + "context" "net/http" "time" @@ -17,7 +18,7 @@ type Connection interface { SendJSONMessage(interface{}) error SetupPingHandler(PingHandler) GenerateMessageID(highPrecision bool) int64 - SendMessageReturnResponse(signature interface{}, request interface{}) ([]byte, error) + SendMessageReturnResponse(ctx context.Context, signature interface{}, request interface{}) ([]byte, error) SendRawMessage(messageType int, message []byte) error SetURL(string) SetProxy(string) diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 7c46183e3e1..de62442b3ef 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -4,8 +4,10 @@ import ( "bytes" "compress/flate" "compress/gzip" + "context" "crypto/rand" "encoding/json" + "errors" "fmt" "io" "math/big" @@ -19,9 +21,11 @@ import ( "github.com/thrasher-corp/gocryptotrader/log" ) +var errMatchTimeout = errors.New("websocket connection: timeout waiting for response with signature") + // SendMessageReturnResponse will send a WS message to the connection and wait // for response -func (w *WebsocketConnection) SendMessageReturnResponse(signature, request interface{}) ([]byte, error) { +func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, signature, request interface{}) ([]byte, error) { outbound, err := json.Marshal(request) if err != nil { return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err) @@ -43,12 +47,15 @@ func (w *WebsocketConnection) SendMessageReturnResponse(signature, request inter case payload := <-ch: timer.Stop() if w.Reporter != nil { - w.Reporter.Latency(w.ExchangeName, payload, time.Since(start)) + w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start)) } return payload, nil case <-timer.C: - w.Match.Timeout(signature) - return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", w.ExchangeName, signature) + w.Match.RemoveSignature(signature) + return nil, fmt.Errorf("%s %w: %v", w.ExchangeName, errMatchTimeout, signature) + case <-ctx.Done(): + w.Match.RemoveSignature(signature) + return nil, ctx.Err() } } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 2d44de95097..ca7e72429dd 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -4,6 +4,7 @@ import ( "bytes" "compress/flate" "compress/gzip" + "context" "encoding/json" "errors" "fmt" @@ -753,10 +754,20 @@ func TestSendMessageWithResponse(t *testing.T) { RequestID: wc.GenerateMessageID(false), } - _, err = wc.SendMessageReturnResponse(request.RequestID, request) + _, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request) if err != nil { t.Error(err) } + + cancelledCtx, fn := context.WithDeadline(context.Background(), time.Now()) + fn() + _, err = wc.SendMessageReturnResponse(cancelledCtx, "123", request) + assert.ErrorIs(t, err, context.DeadlineExceeded) + + // with timeout + wc.ResponseMaxLimit = 1 + _, err = wc.SendMessageReturnResponse(context.Background(), "123", request) + assert.ErrorIs(t, err, errMatchTimeout, "SendMessageReturnResponse should error when request ID not found") } type reporter struct { @@ -1182,7 +1193,7 @@ func TestLatency(t *testing.T) { RequestID: wc.GenerateMessageID(false), } - _, err = wc.SendMessageReturnResponse(request.RequestID, request) + _, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request) if err != nil { t.Error(err) } From 2d2f8722dd81fa13724dbec1d348777206b96cf4 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 19 Jul 2024 13:32:42 +1000 Subject: [PATCH 020/138] stop ping handler routine leak --- exchanges/stream/websocket_connection.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 2157dd9a01a..d7e4073e0fb 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -188,8 +188,8 @@ func (w *WebsocketConnection) SetupPingHandler(handler PingHandler) { return } w.Wg.Add(1) - defer w.Wg.Done() go func() { + defer w.Wg.Done() ticker := time.NewTicker(handler.Delay) for { select { From 39191c865734207b8a836819d9f36825b20d9efa Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 19 Jul 2024 13:47:01 +1000 Subject: [PATCH 021/138] * Fix bug where reader routine on error that is not a disconnection error but websocket frame error or anything really makes the reader routine return and then connection never cycles and the buffer gets filled. * Handle reconnection via an errors.Is check which is simpler and in that scope allow for quick disconnect reconnect without waiting for connection cycle. * Dial now uses code from DialContext but just calls context.Background() * Don't allow reader to return on parse binary response error. Just output error and return a non nil response --- exchanges/stream/websocket.go | 10 ++- exchanges/stream/websocket_connection.go | 80 ++++++++---------------- 2 files changed, 33 insertions(+), 57 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index a68b56ed79f..dda622fa324 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -532,8 +532,7 @@ func (w *Websocket) connectionMonitor() error { } select { case err := <-w.ReadMessageErrors: - w.DataHandler <- err - if IsDisconnectionError(err) { + if errors.Is(err, errConnectionFault) { log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) if w.IsConnected() { if shutdownErr := w.Shutdown(); shutdownErr != nil { @@ -541,6 +540,13 @@ func (w *Websocket) connectionMonitor() error { } } } + // Speedier reconnection, instead of waiting for the next cycle. + if w.IsEnabled() && (!w.IsConnected() && !w.IsConnecting()) { + if err := w.Connect(); err != nil { + log.Errorln(log.WebsocketMgr, err) + } + } + w.DataHandler <- err // hand over the error to the data handler (shutdown and reconnection is priority) case <-timer.C: if !w.IsConnecting() && !w.IsConnected() { err := w.Connect() diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index d7e4073e0fb..9df8749a2d3 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -7,6 +7,7 @@ import ( "context" "crypto/rand" "encoding/json" + "errors" "fmt" "io" "math/big" @@ -20,6 +21,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/log" ) +// errConnectionFault is a connection fault error which alerts the system that a +// connection cycle needs to take place. +var errConnectionFault = errors.New("connection fault") + // SendMessageReturnResponse will send a WS message to the connection and wait // for response func (w *WebsocketConnection) SendMessageReturnResponse(signature, request interface{}) ([]byte, error) { @@ -57,35 +62,7 @@ func (w *WebsocketConnection) SendMessageReturnResponse(signature, request inter // Dial sets proxy urls and then connects to the websocket func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header) error { - if w.ProxyURL != "" { - proxy, err := url.Parse(w.ProxyURL) - if err != nil { - return err - } - dialer.Proxy = http.ProxyURL(proxy) - } - - var err error - var conStatus *http.Response - w.Connection, conStatus, err = dialer.Dial(w.URL, headers) - if err != nil { - if conStatus != nil { - _ = conStatus.Body.Close() - return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", w.ExchangeName, w.URL, conStatus, conStatus.StatusCode, err) - } - return fmt.Errorf("%s websocket connection: %v Error: %w", w.ExchangeName, w.URL, err) - } - _ = conStatus.Body.Close() - - if w.Verbose { - log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", w.ExchangeName, w.URL) - } - select { - case w.Traffic <- struct{}{}: - default: - } - w.setConnectedStatus(true) - return nil + return w.DialContext(context.Background(), dialer, headers) } // DialContext sets proxy urls and then connects to the websocket @@ -231,23 +208,22 @@ func (w *WebsocketConnection) ReadMessage() Response { } mType, resp, err := w.Connection.ReadMessage() if err != nil { - if IsDisconnectionError(err) { - if w.setConnectedStatus(false) { - // NOTE: When w.setConnectedStatus() returns true the underlying - // state was changed and this infers that the connection was - // externally closed and an error is reported else Shutdown() - // method on WebsocketConnection type has been called and can - // be skipped. - select { - case w.readMessageErrors <- err: - default: - // bypass if there is no receiver, as this stops it returning - // when shutdown is called. - log.Warnf(log.WebsocketMgr, - "%s failed to relay error: %v", - w.ExchangeName, - err) - } + // Any error condition will return nil response which will return the + // reader routine and the connection will hang with no readers. This has + // to be handed over to w.readMessageErrors if there is an active + // connection. + if w.setConnectedStatus(false) { + // NOTE: When w.setConnectedStatus() returns true the underlying + // state was changed and this infers that the connection was + // externally closed and an error is reported else Shutdown() + // method on WebsocketConnection type has been called and can + // be skipped. + select { + case w.readMessageErrors <- fmt.Errorf("%w: %w", err, errConnectionFault): + default: + // bypass if there is no receiver, as this stops it returning + // when shutdown is called. + log.Warnf(log.WebsocketMgr, "%s failed to relay error: %v", w.ExchangeName, err) } } return Response{} @@ -265,18 +241,12 @@ func (w *WebsocketConnection) ReadMessage() Response { case websocket.BinaryMessage: standardMessage, err = w.parseBinaryResponse(resp) if err != nil { - log.Errorf(log.WebsocketMgr, - "%v websocket connection: parseBinaryResponse error: %v", - w.ExchangeName, - err) - return Response{} + log.Errorf(log.WebsocketMgr, "%v websocket connection: parseBinaryResponse error: %v", w.ExchangeName, err) + return Response{Raw: []byte(``)} // Non-nil response to avoid the reader returning on this case. } } if w.Verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket connection: message received: %v", - w.ExchangeName, - string(standardMessage)) + log.Debugf(log.WebsocketMgr, "%v websocket connection: message received: %v", w.ExchangeName, string(standardMessage)) } return Response{Raw: standardMessage, Type: mType} } From f1c38956da10efe6c829e9edbc04d41d261ab345 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 19 Jul 2024 13:59:10 +1000 Subject: [PATCH 022/138] Allow rollback on connect on any error across all connections --- exchanges/stream/websocket.go | 45 ++++++++++++++++++++---------- exchanges/stream/websocket_test.go | 3 +- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index dda622fa324..0dc05fd87ae 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -321,7 +321,10 @@ func (w *Websocket) Connect() error { w.trafficMonitor() w.setState(connectingState) - if w.connector != nil { + if !w.useMultiConnectionManagement { + if w.connector == nil { + return fmt.Errorf("%v %w", w.exchangeName, errNoConnectFunc) + } err := w.connector() if err != nil { w.setState(disconnectedState) @@ -353,14 +356,14 @@ func (w *Websocket) Connect() error { return fmt.Errorf("cannot connect: %w", errNoPendingConnections) } - // Assume connected state and if there are any issues below can call Shutdown - w.setState(connectedState) + // multiConnectFatalError is a fatal error that will cause all connections to + // be shutdown and the websocket to be disconnected. + var multiConnectFatalError error - var multiConnectError error // TODO: Implement concurrency below. for i := range w.connectionManager { if w.connectionManager[i].Setup.GenerateSubscriptions == nil { - multiConnectError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketSubscriptionsGeneratorUnset) + multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketSubscriptionsGeneratorUnset) break } @@ -372,7 +375,7 @@ func (w *Websocket) Connect() error { } continue // Non-fatal error, we can continue to the next connection } - multiConnectError = fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) + multiConnectFatalError = fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) break } @@ -385,15 +388,15 @@ func (w *Websocket) Connect() error { } if w.connectionManager[i].Setup.Connector == nil { - multiConnectError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errNoConnectFunc) + multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errNoConnectFunc) break } if w.connectionManager[i].Setup.Handler == nil { - multiConnectError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketDataHandlerUnset) + multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketDataHandlerUnset) break } if w.connectionManager[i].Setup.Subscriber == nil { - multiConnectError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketSubscriberUnset) + multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketSubscriberUnset) break } @@ -403,7 +406,7 @@ func (w *Websocket) Connect() error { err = w.connectionManager[i].Setup.Connector(context.TODO(), conn) if err != nil { - multiConnectError = fmt.Errorf("%v Error connecting %w", w.exchangeName, err) + multiConnectFatalError = fmt.Errorf("%v Error connecting %w", w.exchangeName, err) break } @@ -414,14 +417,23 @@ func (w *Websocket) Connect() error { err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs) if err != nil { - multiConnectError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err) + multiConnectFatalError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err) break } + if w.verbose { + log.Debugf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] connected. [Subscribed: %d]", + w.exchangeName, + i+1, + conn.URL, + len(subs)) + } + w.connectionManager[i].Connection = conn } - if multiConnectError != nil { + if multiConnectFatalError != nil { + // Roll back any successful connections and flush subscriptions for conn, candidate := range w.connections { if err := conn.Shutdown(); err != nil { log.Errorln(log.WebsocketMgr, err) @@ -429,10 +441,15 @@ func (w *Websocket) Connect() error { candidate.Subscriptions.Clear() } clear(w.connections) - w.setState(disconnectedState) - return multiConnectError + w.setState(disconnectedState) // Flip from connecting to disconnected. + return multiConnectFatalError } + // Assume connected state here. All connections have been established. + // All subscriptions have been sent and stored. All data received is being + // handled by the appropriate data handler. + w.setState(connectedState) + if !w.IsConnectionMonitorRunning() { err := w.connectionMonitor() if err != nil { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index c6b661d5303..d5c8dd72f13 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -353,7 +353,8 @@ func TestConnectionMessageErrors(t *testing.T) { assert.EventuallyWithT(t, c, 2*time.Second, 10*time.Millisecond, "Should get an error down the routine") // Test individual connection defined functions - ws.connector = nil + require.NoError(t, ws.Shutdown()) + ws.useMultiConnectionManagement = true err = ws.Connect() assert.ErrorIs(t, err, errNoPendingConnections, "Connect should error correctly") From 09bff6c4ac06bf7db43e62cfc6b82051de0beaaa Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 19 Jul 2024 14:06:19 +1000 Subject: [PATCH 023/138] fix shadow jutsu --- exchanges/stream/websocket.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 0dc05fd87ae..f4b81ba4c60 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -559,8 +559,8 @@ func (w *Websocket) connectionMonitor() error { } // Speedier reconnection, instead of waiting for the next cycle. if w.IsEnabled() && (!w.IsConnected() && !w.IsConnecting()) { - if err := w.Connect(); err != nil { - log.Errorln(log.WebsocketMgr, err) + if connectErr := w.Connect(); connectErr != nil { + log.Errorln(log.WebsocketMgr, connectErr) } } w.DataHandler <- err // hand over the error to the data handler (shutdown and reconnection is priority) From e66c9be72fa416387d22a784bc13b9c8b335fa3a Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 23 Jul 2024 09:40:48 +1000 Subject: [PATCH 024/138] glorious/gk: nitters - adds in ws mock server --- exchanges/exchange.go | 4 +- exchanges/stream/stream_types.go | 4 +- exchanges/stream/websocket.go | 29 ++++++----- exchanges/stream/websocket_connection.go | 19 +++++--- exchanges/stream/websocket_test.go | 62 ++++++++++++++++++------ exchanges/stream/websocket_types.go | 6 +-- 6 files changed, 82 insertions(+), 42 deletions(-) diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 4876ea8a989..75a27905c4b 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -1131,7 +1131,7 @@ func (b *Base) SubscribeToWebsocketChannels(channels subscription.List) error { if b.Websocket == nil { return common.ErrFunctionNotSupported } - return b.Websocket.SubscribeToChannels(nil, channels) + return b.Websocket.SubscribeToChannels(b.Websocket.Conn, channels) } // UnsubscribeToWebsocketChannels removes from ChannelsToSubscribe @@ -1140,7 +1140,7 @@ func (b *Base) UnsubscribeToWebsocketChannels(channels subscription.List) error if b.Websocket == nil { return common.ErrFunctionNotSupported } - return b.Websocket.UnsubscribeChannels(nil, channels) + return b.Websocket.UnsubscribeChannels(b.Websocket.Conn, channels) } // GetSubscriptions returns a copied list of subscriptions diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index f757ff92909..246b729b93b 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -65,10 +65,10 @@ type ConnectionSetup struct { Handler func(ctx context.Context, incoming []byte) error } -// ConnectionCandidate contains the connection setup details to be used when +// ConnectionWrapper contains the connection setup details to be used when // attempting a new connection. It also contains the subscriptions that are // associated with the specific connection. -type ConnectionCandidate struct { +type ConnectionWrapper struct { // Setup contains the connection setup details Setup *ConnectionSetup // Subscriptions contains the subscriptions that are associated with the diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index f4b81ba4c60..af6abc6fd73 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -92,7 +92,7 @@ func NewWebsocket() *Websocket { subscriptions: subscription.NewStore(), features: &protocol.Features{}, Orderbook: buffer.Orderbook{}, - connections: make(map[Connection]*ConnectionCandidate), + connections: make(map[Connection]*ConnectionWrapper), } } @@ -257,7 +257,7 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { } } - w.connectionManager = append(w.connectionManager, ConnectionCandidate{ + w.connectionManager = append(w.connectionManager, ConnectionWrapper{ Setup: c, Subscriptions: subscription.NewStore(), }) @@ -410,11 +410,17 @@ func (w *Websocket) Connect() error { break } + w.connections[conn] = &w.connectionManager[i] + w.connectionManager[i].Connection = conn + + if !conn.IsConnected() { + multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to connect", w.exchangeName, i+1, conn.URL) + break + } + w.Wg.Add(1) go w.Reader(context.TODO(), conn, w.connectionManager[i].Setup.Handler) - w.connections[conn] = &w.connectionManager[i] - err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs) if err != nil { multiConnectFatalError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err) @@ -428,17 +434,18 @@ func (w *Websocket) Connect() error { conn.URL, len(subs)) } - - w.connectionManager[i].Connection = conn } if multiConnectFatalError != nil { // Roll back any successful connections and flush subscriptions - for conn, candidate := range w.connections { - if err := conn.Shutdown(); err != nil { - log.Errorln(log.WebsocketMgr, err) + for x := range w.connectionManager { + if w.connectionManager[x].Connection != nil { + if err := w.connectionManager[x].Connection.Shutdown(); err != nil { + log.Errorln(log.WebsocketMgr, err) + } + w.connectionManager[x].Connection = nil } - candidate.Subscriptions.Clear() + w.connectionManager[x].Subscriptions.Clear() } clear(w.connections) w.setState(disconnectedState) // Flip from connecting to disconnected. @@ -731,7 +738,7 @@ func (w *Websocket) FlushChannels() error { } if len(newsubs) != 0 { // Purge subscription list as there will be conflicts - w.connections[w.connectionManager[x].Connection].Subscriptions.Clear() + w.connectionManager[x].Subscriptions.Clear() err = w.SubscribeToChannels(w.connectionManager[x].Connection, newsubs) if err != nil { return err diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 9df8749a2d3..4bf54617c97 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -203,15 +203,14 @@ func (w *WebsocketConnection) IsConnected() bool { // ReadMessage reads messages, can handle text, gzip and binary func (w *WebsocketConnection) ReadMessage() Response { - if w.Connection == nil { - return Response{} - } mType, resp, err := w.Connection.ReadMessage() if err != nil { - // Any error condition will return nil response which will return the - // reader routine and the connection will hang with no readers. This has - // to be handed over to w.readMessageErrors if there is an active - // connection. + // Any error condition will return a Response{Raw: nil, Type: 0} which + // will force the reader routine to return. The connection will hang + // with no reader routine and its buffer will be written to from the + // active websocket connection. This should be handed over to + // `w.readMessageErrors` and managed by 'connectionMonitor' which needs + // to flush, reconnect and resubscribe the connection. if w.setConnectedStatus(false) { // NOTE: When w.setConnectedStatus() returns true the underlying // state was changed and this infers that the connection was @@ -293,7 +292,11 @@ func (w *WebsocketConnection) Shutdown() error { return nil } w.setConnectedStatus(false) - return w.Connection.UnderlyingConn().Close() + w.writeControl.Lock() + defer w.writeControl.Unlock() + defer w.Connection.NetConn().Close() // Ungraceful close as backup + // Gracefully close the connection + return w.Connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) } // SetURL sets connection URL diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index d5c8dd72f13..8847796bee4 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -10,6 +10,7 @@ import ( "fmt" "net" "net/http" + "net/http/httptest" "os" "strconv" "strings" @@ -107,10 +108,15 @@ func (d *dodgyConnection) Connect() error { return fmt.Errorf("cannot connect: %w", errDastardlyReason) } +var mock *httptest.Server + func TestMain(m *testing.M) { // Change trafficCheckInterval for TestTrafficMonitorTimeout before parallel tests to avoid racing trafficCheckInterval = 50 * time.Millisecond - os.Exit(m.Run()) + mock = httptest.NewServer(http.HandlerFunc(websocketServerMockEcho)) + r := m.Run() + mock.Close() + os.Exit(r) } func TestSetup(t *testing.T) { @@ -359,7 +365,8 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Connect() assert.ErrorIs(t, err, errNoPendingConnections, "Connect should error correctly") - ws.connectionManager = []ConnectionCandidate{{Setup: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}} + ws.useMultiConnectionManagement = true + ws.connectionManager = []ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws" + mock.URL[len("http"):] + "/ws"}}} err = ws.Connect() require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) @@ -393,8 +400,8 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.connectionManager[0].Setup.Connector = func(context.Context, Connection) error { - return nil + ws.connectionManager[0].Setup.Connector = func(ctx context.Context, conn Connection) error { + return conn.DialContext(ctx, websocket.DefaultDialer, nil) } err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) @@ -603,7 +610,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { require.NoError(t, multi.SetupNewConnection(amazingCandidate)) amazingConn := multi.getConnectionFromSetup(amazingCandidate) - multi.connections = map[Connection]*ConnectionCandidate{ + multi.connections = map[Connection]*ConnectionWrapper{ amazingConn: &multi.connectionManager[0], } @@ -1102,7 +1109,7 @@ func TestGetChannelDifference(t *testing.T) { require.Equal(t, 1, len(subs)) require.Empty(t, unsubs, "Should get no unsubs") - w.connections = map[Connection]*ConnectionCandidate{ + w.connections = map[Connection]*ConnectionWrapper{ sweetConn: {Setup: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}, } @@ -1302,17 +1309,7 @@ func TestSetupNewConnection(t *testing.T) { // Test connection candidates for multi connection tracking. multi := NewWebsocket() set := *defaultSetup - - // Values below are now not necessary as this will be set per connection - // candidate in SetupNewConnection. set.UseMultiConnectionManagement = true - set.Connector = nil - set.Subscriber = nil - set.Unsubscriber = nil - set.GenerateSubscriptions = nil - set.DefaultURL = "" - set.RunningURL = "" - require.NoError(t, multi.Setup(&set)) connSetup := &ConnectionSetup{} @@ -1448,3 +1445,36 @@ func TestCheckSubscriptions(t *testing.T) { err = ws.checkSubscriptions(nil, subscription.List{{}}) assert.NoError(t, err, "checkSubscriptions should not error") } + +// websocketServerMockEcho is a mock websocket server that echos messages back to the client +func websocketServerMockEcho(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + fmt.Println("Upgrade error:", err) + return + } + defer conn.Close() + + for { + messageType, message, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { + panic(err) + } + break + } + err = conn.WriteMessage(messageType, message) + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { + panic(err) + } + break + } + } +} diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 529798a36df..13b73630ab8 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -50,12 +50,12 @@ type Websocket struct { m sync.Mutex connector func() error - // ConnectionManager contains the connection candidates and the current + // connectionWrapper contains the connection wrappers and the current // connections - connectionManager []ConnectionCandidate + connectionManager []ConnectionWrapper // Connections contains the current connections with their associated // connection candidates - connections map[Connection]*ConnectionCandidate + connections map[Connection]*ConnectionWrapper subscriptions *subscription.Store From 03de669f9c01278fdfd9eccd8310010fe29eafcd Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 23 Jul 2024 09:45:11 +1000 Subject: [PATCH 025/138] linter: fix --- exchanges/stream/websocket_test.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 8847796bee4..d976c1e4fbe 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1448,11 +1448,7 @@ func TestCheckSubscriptions(t *testing.T) { // websocketServerMockEcho is a mock websocket server that echos messages back to the client func websocketServerMockEcho(w http.ResponseWriter, r *http.Request) { - var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, - } + upgrader := websocket.Upgrader{CheckOrigin: func(_ *http.Request) bool { return true }} conn, err := upgrader.Upgrade(w, r, nil) if err != nil { From 8159b05417e8a2cc7d0fa5945b25b550ec287c15 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 24 Jul 2024 10:35:04 +1000 Subject: [PATCH 026/138] fix deadlock on connection as the previous channel had no reader and would hang connection reader for eternity. --- exchanges/stream/websocket.go | 2 +- exchanges/stream/websocket_test.go | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index af6abc6fd73..410a2b758d7 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -1277,7 +1277,7 @@ func (w *Websocket) Reader(ctx context.Context, conn Connection, handler func(ct return // Connection has been closed } if err := handler(ctx, resp.Raw); err != nil { - w.ReadMessageErrors <- err + w.DataHandler <- err } } } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index d976c1e4fbe..52976840441 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -407,7 +407,7 @@ func TestConnectionMessageErrors(t *testing.T) { require.ErrorIs(t, err, errDastardlyReason) ws.connectionManager[0].Setup.Handler = func(context.Context, []byte) error { - return nil + return errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) @@ -416,6 +416,11 @@ func TestConnectionMessageErrors(t *testing.T) { return nil } err = ws.Connect() + require.NoError(t, err) + + err = ws.connectionManager[0].Connection.SendRawMessage(websocket.TextMessage, []byte("test")) + require.NoError(t, err) + require.NoError(t, err) require.NoError(t, ws.Shutdown()) } From 0e3bb312f8bbbef0a65b4e563db89bedc6a5410a Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 24 Jul 2024 10:40:44 +1000 Subject: [PATCH 027/138] glorious: whooops --- exchanges/stream/websocket_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index ca7e72429dd..dadc824668f 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -725,8 +725,7 @@ func TestSendMessage(t *testing.T) { } } -// TestSendMessageWithResponse logic test -func TestSendMessageWithResponse(t *testing.T) { +func TestSendMessageReturnResponse(t *testing.T) { t.Parallel() wc := &WebsocketConnection{ Verbose: true, From f98c3aa9d6bcd73dc0718ff829b1226d0c92f52b Mon Sep 17 00:00:00 2001 From: shazbert Date: Wed, 24 Jul 2024 14:23:34 +1000 Subject: [PATCH 028/138] gk: nits --- exchanges/binance/binance_test.go | 5 +- exchanges/kraken/kraken_test.go | 5 +- exchanges/stream/websocket_connection.go | 4 +- exchanges/stream/websocket_test.go | 38 ++------------- exchanges/stream/websocket_types.go | 6 +-- internal/testing/exchange/exchange.go | 33 ------------- internal/testing/exchange/exchange_test.go | 3 +- internal/testing/websocket/mock.go | 57 ++++++++++++++++++++++ internal/testing/websocket/mock_test.go | 1 + 9 files changed, 73 insertions(+), 79 deletions(-) create mode 100644 internal/testing/websocket/mock.go create mode 100644 internal/testing/websocket/mock_test.go diff --git a/exchanges/binance/binance_test.go b/exchanges/binance/binance_test.go index b4b95e1d5da..3aec0407627 100644 --- a/exchanges/binance/binance_test.go +++ b/exchanges/binance/binance_test.go @@ -31,6 +31,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" + mockws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -1993,7 +1994,7 @@ func TestSubscribe(t *testing.T) { require.ElementsMatch(t, req.Params, exp, "Params should have correct channels") return w.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf(`{"result":null,"id":%d}`, req.ID))) } - b = testexch.MockWsInstance[Binance](t, testexch.CurryWsMockUpgrader(t, mock)) + b = testexch.MockWsInstance[Binance](t, mockws.CurryWsMockUpgrader(t, mock)) } else { testexch.SetupWs(t, b) } @@ -2014,7 +2015,7 @@ func TestSubscribeBadResp(t *testing.T) { require.NoError(t, err, "Unmarshal should not error") return w.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf(`{"result":{"error":"carrots"},"id":%d}`, req.ID))) } - b := testexch.MockWsInstance[Binance](t, testexch.CurryWsMockUpgrader(t, mock)) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + b := testexch.MockWsInstance[Binance](t, mockws.CurryWsMockUpgrader(t, mock)) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes err := b.Subscribe(channels) assert.ErrorIs(t, err, errUnknownError, "Subscribe should error errUnknownError") assert.ErrorContains(t, err, "carrots", "Subscribe should error containing the carrots") diff --git a/exchanges/kraken/kraken_test.go b/exchanges/kraken/kraken_test.go index c55eac95d00..c852265cc5b 100644 --- a/exchanges/kraken/kraken_test.go +++ b/exchanges/kraken/kraken_test.go @@ -32,6 +32,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" + mockws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -2290,7 +2291,7 @@ func TestGetOpenInterest(t *testing.T) { } // curryWsMockUpgrader handles Kraken specific http auth token responses prior to handling off to standard Websocket upgrader -func curryWsMockUpgrader(tb testing.TB, h testexch.WsMockFunc) http.HandlerFunc { +func curryWsMockUpgrader(tb testing.TB, h mockws.WsMockFunc) http.HandlerFunc { tb.Helper() return func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "GetWebSocketsToken") { @@ -2298,7 +2299,7 @@ func curryWsMockUpgrader(tb testing.TB, h testexch.WsMockFunc) http.HandlerFunc require.NoError(tb, err, "Write should not error") return } - testexch.WsMockUpgrader(tb, w, r, h) + mockws.WsMockUpgrader(tb, w, r, h) } } diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 4bf54617c97..b421399dbf3 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -294,9 +294,7 @@ func (w *WebsocketConnection) Shutdown() error { w.setConnectedStatus(false) w.writeControl.Lock() defer w.writeControl.Unlock() - defer w.Connection.NetConn().Close() // Ungraceful close as backup - // Gracefully close the connection - return w.Connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + return w.Connection.NetConn().Close() } // SetURL sets connection URL diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 52976840441..0660b4e46b7 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -26,6 +26,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + mockws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket" ) const ( @@ -108,14 +109,11 @@ func (d *dodgyConnection) Connect() error { return fmt.Errorf("cannot connect: %w", errDastardlyReason) } -var mock *httptest.Server - func TestMain(m *testing.M) { // Change trafficCheckInterval for TestTrafficMonitorTimeout before parallel tests to avoid racing trafficCheckInterval = 50 * time.Millisecond - mock = httptest.NewServer(http.HandlerFunc(websocketServerMockEcho)) r := m.Run() - mock.Close() + os.Exit(r) } @@ -366,6 +364,9 @@ func TestConnectionMessageErrors(t *testing.T) { assert.ErrorIs(t, err, errNoPendingConnections, "Connect should error correctly") ws.useMultiConnectionManagement = true + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.WsWithEcho) })) + defer mock.Close() ws.connectionManager = []ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws" + mock.URL[len("http"):] + "/ws"}}} err = ws.Connect() require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) @@ -1450,32 +1451,3 @@ func TestCheckSubscriptions(t *testing.T) { err = ws.checkSubscriptions(nil, subscription.List{{}}) assert.NoError(t, err, "checkSubscriptions should not error") } - -// websocketServerMockEcho is a mock websocket server that echos messages back to the client -func websocketServerMockEcho(w http.ResponseWriter, r *http.Request) { - upgrader := websocket.Upgrader{CheckOrigin: func(_ *http.Request) bool { return true }} - - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - fmt.Println("Upgrade error:", err) - return - } - defer conn.Close() - - for { - messageType, message, err := conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { - panic(err) - } - break - } - err = conn.WriteMessage(messageType, message) - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { - panic(err) - } - break - } - } -} diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 13b73630ab8..f663cb7c0cd 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -50,12 +50,8 @@ type Websocket struct { m sync.Mutex connector func() error - // connectionWrapper contains the connection wrappers and the current - // connections connectionManager []ConnectionWrapper - // Connections contains the current connections with their associated - // connection candidates - connections map[Connection]*ConnectionWrapper + connections map[Connection]*ConnectionWrapper subscriptions *subscription.Store diff --git a/internal/testing/exchange/exchange.go b/internal/testing/exchange/exchange.go index 85de1fea761..fe82b4f5435 100644 --- a/internal/testing/exchange/exchange.go +++ b/internal/testing/exchange/exchange.go @@ -14,7 +14,6 @@ import ( "sync" "testing" - "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/config" @@ -86,11 +85,6 @@ func MockHTTPInstance(e exchange.IBotExchange) error { return nil } -var upgrader = websocket.Upgrader{} - -// WsMockFunc is a websocket handler to be called with each websocket message -type WsMockFunc func([]byte, *websocket.Conn) error - // MockWsInstance creates a new Exchange instance with a mock websocket instance and HTTP server // It accepts an exchange package type argument and a http.HandlerFunc // See CurryWsMockUpgrader for a convenient way to curry t and a ws mock function @@ -128,33 +122,6 @@ func MockWsInstance[T any, PT interface { return e } -// CurryWsMockUpgrader curries a WsMockUpgrader with a testing.TB and a mock func -// bridging the gap between information known before the Server is created and during a request -func CurryWsMockUpgrader(tb testing.TB, wsHandler WsMockFunc) http.HandlerFunc { - tb.Helper() - return func(w http.ResponseWriter, r *http.Request) { - WsMockUpgrader(tb, w, r, wsHandler) - } -} - -// WsMockUpgrader handles upgrading an initial HTTP request to WS, and then runs a for loop calling the mock func on each input -func WsMockUpgrader(tb testing.TB, w http.ResponseWriter, r *http.Request, wsHandler WsMockFunc) { - tb.Helper() - c, err := upgrader.Upgrade(w, r, nil) - require.NoError(tb, err, "Upgrade connection should not error") - defer c.Close() - for { - _, p, err := c.ReadMessage() - if websocket.IsUnexpectedCloseError(err) { - return - } - require.NoError(tb, err, "ReadMessage should not error") - - err = wsHandler(p, c) - assert.NoError(tb, err, "WS Mock Function should not error") - } -} - // FixtureToDataHandler squirts the contents of a file to a reader function (probably e.wsHandleData) func FixtureToDataHandler(tb testing.TB, fixturePath string, reader func([]byte) error) { tb.Helper() diff --git a/internal/testing/exchange/exchange_test.go b/internal/testing/exchange/exchange_test.go index d6796c9ebbc..cc4ab52844a 100644 --- a/internal/testing/exchange/exchange_test.go +++ b/internal/testing/exchange/exchange_test.go @@ -9,6 +9,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/exchanges/binance" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + mockws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket" ) // TestSetup exercises Setup @@ -30,6 +31,6 @@ func TestMockHTTPInstance(t *testing.T) { // TestMockWsInstance exercises MockWsInstance func TestMockWsInstance(t *testing.T) { - b := MockWsInstance[binance.Binance](t, CurryWsMockUpgrader(t, func(_ []byte, _ *websocket.Conn) error { return nil })) + b := MockWsInstance[binance.Binance](t, mockws.CurryWsMockUpgrader(t, func(_ []byte, _ *websocket.Conn) error { return nil })) require.NotNil(t, b, "MockWsInstance must not be nil") } diff --git a/internal/testing/websocket/mock.go b/internal/testing/websocket/mock.go new file mode 100644 index 00000000000..5905a29daad --- /dev/null +++ b/internal/testing/websocket/mock.go @@ -0,0 +1,57 @@ +package websocket + +import ( + "net/http" + "strings" + "testing" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var upgrader = websocket.Upgrader{CheckOrigin: func(_ *http.Request) bool { return true }} + +// WsMockFunc is a websocket handler to be called with each websocket message +type WsMockFunc func([]byte, *websocket.Conn) error + +// CurryWsMockUpgrader curries a WsMockUpgrader with a testing.TB and a mock func +// bridging the gap between information known before the Server is created and during a request +func CurryWsMockUpgrader(tb testing.TB, wsHandler WsMockFunc) http.HandlerFunc { + tb.Helper() + return func(w http.ResponseWriter, r *http.Request) { + WsMockUpgrader(tb, w, r, wsHandler) + } +} + +// WsMockUpgrader handles upgrading an initial HTTP request to WS, and then runs a for loop calling the mock func on each input +func WsMockUpgrader(tb testing.TB, w http.ResponseWriter, r *http.Request, wsHandler WsMockFunc) { + tb.Helper() + c, err := upgrader.Upgrade(w, r, nil) + require.NoError(tb, err, "Upgrade connection should not error") + defer c.Close() + for { + _, p, err := c.ReadMessage() + if websocket.IsUnexpectedCloseError(err) { + return + } + + if err != nil && strings.Contains(err.Error(), "wsarecv: An established connection was aborted by the software in your host machine.") { + return + } + + require.NoError(tb, err, "ReadMessage should not error") + + err = wsHandler(p, c) + assert.NoError(tb, err, "WS Mock Function should not error") + } +} + +// WsWithEcho is a simple echo function after a read +func WsWithEcho(p []byte, c *websocket.Conn) error { + err := c.WriteMessage(websocket.TextMessage, p) + if err != nil { + return err + } + return nil +} diff --git a/internal/testing/websocket/mock_test.go b/internal/testing/websocket/mock_test.go new file mode 100644 index 00000000000..708bc8cb5dd --- /dev/null +++ b/internal/testing/websocket/mock_test.go @@ -0,0 +1 @@ +package websocket From d89a46a932d4aab91817aaef1b3026d77cb4614a Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 25 Jul 2024 12:50:44 +1000 Subject: [PATCH 029/138] Leak issue and edge case --- exchanges/stream/websocket.go | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 410a2b758d7..5d64b1df8d2 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -83,11 +83,15 @@ func SetupGlobalReporter(r Reporter) { // NewWebsocket initialises the websocket struct func NewWebsocket() *Websocket { return &Websocket{ - DataHandler: make(chan interface{}, jobBuffer), - ToRoutine: make(chan interface{}, jobBuffer), - ShutdownC: make(chan struct{}), - TrafficAlert: make(chan struct{}, 1), - ReadMessageErrors: make(chan error), + DataHandler: make(chan interface{}, jobBuffer), + ToRoutine: make(chan interface{}, jobBuffer), + ShutdownC: make(chan struct{}), + TrafficAlert: make(chan struct{}, 1), + // ReadMessageErrors is buffered for an edge case when `Connect` fails + // after subscriptions are made but before the connectionMonitor has + // started. This allows the error to be read and handled in the + // connectionMonitor and start a connection cycle again. + ReadMessageErrors: make(chan error, 1), Match: NewMatch(), subscriptions: subscription.NewStore(), features: &protocol.Features{}, @@ -449,6 +453,14 @@ func (w *Websocket) Connect() error { } clear(w.connections) w.setState(disconnectedState) // Flip from connecting to disconnected. + + // Drain residual error in the single buffered channel, this mitigates + // the cycle when `Connect` is called again and the connectionMonitor + // starts but there is an old error in the channel. + select { + case <-w.ReadMessageErrors: + default: + } return multiConnectFatalError } @@ -646,6 +658,15 @@ func (w *Websocket) Shutdown() error { if w.verbose { log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName) } + + // Drain residual error in the single buffered channel, this mitigates + // the cycle when `Connect` is called again and the connectionMonitor + // starts but there is an old error in the channel. + select { + case <-w.ReadMessageErrors: + default: + } + return nil } From 8281b88a4d773b9c8d2a06cd91561b09766d21c5 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 8 Apr 2024 19:51:41 +0700 Subject: [PATCH 030/138] Websocket: Add SendMessageReturnResponses --- exchanges/bitfinex/bitfinex_test.go | 4 ++ exchanges/stream/stream_match.go | 31 ++++---- exchanges/stream/stream_match_test.go | 40 ++++------- exchanges/stream/stream_types.go | 5 +- exchanges/stream/websocket.go | 2 + exchanges/stream/websocket_connection.go | 92 ++++++++++++++---------- exchanges/stream/websocket_test.go | 2 +- 7 files changed, 96 insertions(+), 80 deletions(-) diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index 86a1cfa7293..f66b665974c 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -1316,7 +1316,11 @@ func TestWsCancelOffer(t *testing.T) { } func TestWsSubscribedResponse(t *testing.T) { +<<<<<<< HEAD ch, err := b.Websocket.Match.Set("subscribe:waiter1") +======= + m, err := b.Websocket.Match.Set("subscribe:waiter1", 1) +>>>>>>> 507c12f1d (Websocket: Add SendMessageReturnResponses) assert.NoError(t, err, "Setting a matcher should not error") err = b.wsHandleData([]byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`)) if assert.Error(t, err, "Should error if sub is not registered yet") { diff --git a/exchanges/stream/stream_match.go b/exchanges/stream/stream_match.go index 5a024fc3f49..03f0340f2ef 100644 --- a/exchanges/stream/stream_match.go +++ b/exchanges/stream/stream_match.go @@ -9,7 +9,7 @@ var errSignatureCollision = errors.New("signature collision") // NewMatch returns a new Match func NewMatch() *Match { - return &Match{m: make(map[any]chan<- []byte)} + return &Match{m: make(map[any]Incoming)} } // Match is a distributed subtype that handles the matching of requests and @@ -17,13 +17,14 @@ func NewMatch() *Match { // connections. Stream systems fan in all incoming payloads to one routine for // processing. type Match struct { - m map[any]chan<- []byte + m map[any]Incoming mu sync.Mutex } -// Incoming matches with request, disregarding the returned payload -func (m *Match) Incoming(signature any) bool { - return m.IncomingWithData(signature, nil) +// Incoming is a sub-type that handles incoming data +type Incoming struct { + count int // Number of responses expected + waitingRoutine chan<- []byte } // IncomingWithData matches with requests and takes in the returned payload, to @@ -35,21 +36,27 @@ func (m *Match) IncomingWithData(signature any, data []byte) bool { if !ok { return false } - ch <- data - close(ch) - delete(m.m, signature) + ch.waitingRoutine <- data + ch.count-- + if ch.count == 0 { + close(ch.waitingRoutine) + delete(m.m, signature) + } return true } // Set the signature response channel for incoming data -func (m *Match) Set(signature any) (<-chan []byte, error) { +func (m *Match) Set(signature any, bufSize int) (<-chan []byte, error) { + if bufSize < 0 { + bufSize = 1 + } m.mu.Lock() defer m.mu.Unlock() if _, ok := m.m[signature]; ok { return nil, errSignatureCollision } - ch := make(chan []byte, 1) // This is buffered so we don't need to wait for receiver. - m.m[signature] = ch + ch := make(chan []byte, bufSize) + m.m[signature] = Incoming{count: bufSize, waitingRoutine: ch} return ch, nil } @@ -58,7 +65,7 @@ func (m *Match) RemoveSignature(signature any) { m.mu.Lock() defer m.mu.Unlock() if ch, ok := m.m[signature]; ok { - close(ch) + close(ch.waitingRoutine) delete(m.m, signature) } } diff --git a/exchanges/stream/stream_match_test.go b/exchanges/stream/stream_match_test.go index 3530c12e686..dc1b933bd2e 100644 --- a/exchanges/stream/stream_match_test.go +++ b/exchanges/stream/stream_match_test.go @@ -3,40 +3,28 @@ package stream import ( "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestMatch(t *testing.T) { t.Parallel() - nm := NewMatch() - require.False(t, nm.Incoming("wow")) + load := []byte("42") + assert.False(t, new(Match).IncomingWithData("hello", load), "Should not match an uninitilized Match") - // try to match with unset signature - require.False(t, nm.Incoming("hello")) + match := NewMatch() + assert.False(t, match.IncomingWithData("hello", load), "Should not match an empty signature") - ch, err := nm.Set("hello") - require.NoError(t, err) + ch, err := match.Set("hello", 2) + require.NoError(t, err, "Set must not error") + assert.True(t, match.IncomingWithData("hello", []byte("hello"))) + assert.Equal(t, "hello", string(<-ch)) - _, err = nm.Set("hello") - require.ErrorIs(t, err, errSignatureCollision) + _, err = match.Set("hello", 1) + assert.ErrorIs(t, err, errSignatureCollision, "Should error on signature collision") - // try and match with initial payload - require.True(t, nm.Incoming("hello")) - require.Nil(t, <-ch) + assert.True(t, match.IncomingWithData("hello", load), "Should match with matching message and signature") + assert.True(t, match.IncomingWithData("hello", load), "Should match with matching message and signature") - // put in secondary payload with conflicting signature - require.False(t, nm.Incoming("hello")) - - ch, err = nm.Set("hello") - require.NoError(t, err) - - expected := []byte("payload") - require.True(t, nm.IncomingWithData("hello", expected)) - - require.Equal(t, expected, <-ch) - - _, err = nm.Set("purge me") - require.NoError(t, err) - nm.RemoveSignature("purge me") - require.False(t, nm.IncomingWithData("purge me", expected)) + assert.Len(t, ch, 2, "Channel should have 2 items") } diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 578beb2f0d8..9421950cc85 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -15,10 +15,11 @@ import ( type Connection interface { Dial(*websocket.Dialer, http.Header) error ReadMessage() Response - SendJSONMessage(interface{}) error + SendJSONMessage(any) error SetupPingHandler(PingHandler) GenerateMessageID(highPrecision bool) int64 - SendMessageReturnResponse(ctx context.Context, signature interface{}, request interface{}) ([]byte, error) + SendMessageReturnResponse(ctx context.Context, signature any, request any) ([]byte, error) + SendMessageReturnResponses(ctx context.Context, signature any, request any, expected int) ([][]byte, error) SendRawMessage(messageType int, message []byte) error SetURL(string) SetProxy(string) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index badb5d565a4..b310ca5bb05 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -29,6 +29,8 @@ var ( ErrUnsubscribeFailure = errors.New("unsubscribe failure") ErrAlreadyDisabled = errors.New("websocket already disabled") ErrNotConnected = errors.New("websocket is not connected") + ErrNoMessageListener = errors.New("websocket listener not found for message") + ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature") ) // Private websocket errors diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index de62442b3ef..e6ea2ffa89e 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -7,7 +7,6 @@ import ( "context" "crypto/rand" "encoding/json" - "errors" "fmt" "io" "math/big" @@ -21,44 +20,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/log" ) -var errMatchTimeout = errors.New("websocket connection: timeout waiting for response with signature") - -// SendMessageReturnResponse will send a WS message to the connection and wait -// for response -func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, signature, request interface{}) ([]byte, error) { - outbound, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err) - } - - ch, err := w.Match.Set(signature) - if err != nil { - return nil, err - } - - start := time.Now() - err = w.SendRawMessage(websocket.TextMessage, outbound) - if err != nil { - return nil, err - } - - timer := time.NewTimer(w.ResponseMaxLimit) - select { - case payload := <-ch: - timer.Stop() - if w.Reporter != nil { - w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start)) - } - return payload, nil - case <-timer.C: - w.Match.RemoveSignature(signature) - return nil, fmt.Errorf("%s %w: %v", w.ExchangeName, errMatchTimeout, signature) - case <-ctx.Done(): - w.Match.RemoveSignature(signature) - return nil, ctx.Err() - } -} - // Dial sets proxy urls and then connects to the websocket func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header) error { if w.ProxyURL != "" { @@ -308,3 +269,56 @@ func (w *WebsocketConnection) SetProxy(proxy string) { func (w *WebsocketConnection) GetURL() string { return w.URL } + +// SendMessageReturnResponse will send a WS message to the connection and wait for response +func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, signature, request any) ([]byte, error) { + resps, err := w.SendMessageReturnResponses(ctx, signature, request, 1) + if err != nil { + return nil, err + } + return resps[0], nil +} + +// SendMessageReturnResponses will send a WS message to the connection and wait for N responses +// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked +func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, signature, request any, expected int) ([][]byte, error) { + outbound, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err) + } + + ch, err := w.Match.Set(signature, expected) + if err != nil { + return nil, err + } + + start := time.Now() + err = w.SendRawMessage(websocket.TextMessage, outbound) + if err != nil { + return nil, err + } + + timeout := time.NewTimer(w.ResponseMaxLimit * time.Duration(expected)) + + resps := make([][]byte, 0, expected) + for err == nil && len(resps) < expected { + select { + case resp := <-ch: + resps = append(resps, resp) + case <-timeout.C: + w.Match.RemoveSignature(signature) + err = fmt.Errorf("%s %w %v", w.ExchangeName, ErrSignatureTimeout, signature) + case <-ctx.Done(): + w.Match.RemoveSignature(signature) + err = ctx.Err() + } + } + + timeout.Stop() + + if err == nil && w.Reporter != nil { + w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start)) + } + + return resps, err +} diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index dadc824668f..840150fb9fb 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -766,7 +766,7 @@ func TestSendMessageReturnResponse(t *testing.T) { // with timeout wc.ResponseMaxLimit = 1 _, err = wc.SendMessageReturnResponse(context.Background(), "123", request) - assert.ErrorIs(t, err, errMatchTimeout, "SendMessageReturnResponse should error when request ID not found") + assert.ErrorIs(t, err, ErrSignatureTimeout, "SendMessageReturnResponse should error when request ID not found") } type reporter struct { From a38a1d1b06b4c88a504475791a883d30ab13da1e Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 25 Jul 2024 15:01:34 +1000 Subject: [PATCH 031/138] whooooooopsie --- exchanges/bitfinex/bitfinex_test.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index f66b665974c..20ef7eef0ed 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -1316,11 +1316,7 @@ func TestWsCancelOffer(t *testing.T) { } func TestWsSubscribedResponse(t *testing.T) { -<<<<<<< HEAD - ch, err := b.Websocket.Match.Set("subscribe:waiter1") -======= - m, err := b.Websocket.Match.Set("subscribe:waiter1", 1) ->>>>>>> 507c12f1d (Websocket: Add SendMessageReturnResponses) + ch, err := b.Websocket.Match.Set("subscribe:waiter1", 1) assert.NoError(t, err, "Setting a matcher should not error") err = b.wsHandleData([]byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`)) if assert.Error(t, err, "Should error if sub is not registered yet") { From 77ef3664e490fe6825d96fe3416d0c7a15e439c1 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 26 Jul 2024 15:24:09 +1000 Subject: [PATCH 032/138] gk: nitssssss --- exchanges/stream/stream_match.go | 31 ++++++++++++++------------- exchanges/stream/stream_match_test.go | 2 ++ 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/exchanges/stream/stream_match.go b/exchanges/stream/stream_match.go index 03f0340f2ef..c94571f4c2d 100644 --- a/exchanges/stream/stream_match.go +++ b/exchanges/stream/stream_match.go @@ -5,11 +5,14 @@ import ( "sync" ) -var errSignatureCollision = errors.New("signature collision") +var ( + errSignatureCollision = errors.New("signature collision") + errBufferShouldBeGreaterThanZero = errors.New("buffer size should be greater than zero") +) // NewMatch returns a new Match func NewMatch() *Match { - return &Match{m: make(map[any]Incoming)} + return &Match{m: make(map[any]incoming)} } // Match is a distributed subtype that handles the matching of requests and @@ -17,14 +20,13 @@ func NewMatch() *Match { // connections. Stream systems fan in all incoming payloads to one routine for // processing. type Match struct { - m map[any]Incoming + m map[any]incoming mu sync.Mutex } -// Incoming is a sub-type that handles incoming data -type Incoming struct { - count int // Number of responses expected - waitingRoutine chan<- []byte +type incoming struct { + expected int + c chan<- []byte } // IncomingWithData matches with requests and takes in the returned payload, to @@ -36,10 +38,9 @@ func (m *Match) IncomingWithData(signature any, data []byte) bool { if !ok { return false } - ch.waitingRoutine <- data - ch.count-- - if ch.count == 0 { - close(ch.waitingRoutine) + ch.c <- data + if ch.expected--; ch.expected == 0 { + close(ch.c) delete(m.m, signature) } return true @@ -47,8 +48,8 @@ func (m *Match) IncomingWithData(signature any, data []byte) bool { // Set the signature response channel for incoming data func (m *Match) Set(signature any, bufSize int) (<-chan []byte, error) { - if bufSize < 0 { - bufSize = 1 + if bufSize <= 0 { + return nil, errBufferShouldBeGreaterThanZero } m.mu.Lock() defer m.mu.Unlock() @@ -56,7 +57,7 @@ func (m *Match) Set(signature any, bufSize int) (<-chan []byte, error) { return nil, errSignatureCollision } ch := make(chan []byte, bufSize) - m.m[signature] = Incoming{count: bufSize, waitingRoutine: ch} + m.m[signature] = incoming{expected: bufSize, c: ch} return ch, nil } @@ -65,7 +66,7 @@ func (m *Match) RemoveSignature(signature any) { m.mu.Lock() defer m.mu.Unlock() if ch, ok := m.m[signature]; ok { - close(ch.waitingRoutine) + close(ch.c) delete(m.m, signature) } } diff --git a/exchanges/stream/stream_match_test.go b/exchanges/stream/stream_match_test.go index dc1b933bd2e..726e6e15a96 100644 --- a/exchanges/stream/stream_match_test.go +++ b/exchanges/stream/stream_match_test.go @@ -15,6 +15,8 @@ func TestMatch(t *testing.T) { match := NewMatch() assert.False(t, match.IncomingWithData("hello", load), "Should not match an empty signature") + _, err := match.Set("hello", -0) + require.ErrorIs(t, err, errBufferShouldBeGreaterThanZero, "Should error on buffer size less than 0") ch, err := match.Set("hello", 2) require.NoError(t, err, "Set must not error") assert.True(t, match.IncomingWithData("hello", []byte("hello"))) From 6341ccf614fdf7659ce631d8be22217b1748ad39 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 26 Jul 2024 15:32:08 +1000 Subject: [PATCH 033/138] Update exchanges/stream/stream_match.go Co-authored-by: Gareth Kirwan --- exchanges/stream/stream_match.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/stream_match.go b/exchanges/stream/stream_match.go index c94571f4c2d..94ed0109d35 100644 --- a/exchanges/stream/stream_match.go +++ b/exchanges/stream/stream_match.go @@ -7,7 +7,7 @@ import ( var ( errSignatureCollision = errors.New("signature collision") - errBufferShouldBeGreaterThanZero = errors.New("buffer size should be greater than zero") + errInvalidBufferSize = errors.New("buffer size must be positive") ) // NewMatch returns a new Match From c0d1d43346e543953fd03574b566756ccb1f0159 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 26 Jul 2024 15:32:15 +1000 Subject: [PATCH 034/138] Update exchanges/stream/stream_match_test.go Co-authored-by: Gareth Kirwan --- exchanges/stream/stream_match_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/exchanges/stream/stream_match_test.go b/exchanges/stream/stream_match_test.go index 726e6e15a96..df51fabb04e 100644 --- a/exchanges/stream/stream_match_test.go +++ b/exchanges/stream/stream_match_test.go @@ -15,8 +15,10 @@ func TestMatch(t *testing.T) { match := NewMatch() assert.False(t, match.IncomingWithData("hello", load), "Should not match an empty signature") - _, err := match.Set("hello", -0) - require.ErrorIs(t, err, errBufferShouldBeGreaterThanZero, "Should error on buffer size less than 0") + _, err := match.Set("hello", 0) + require.ErrorIs(t, err, errInvalidBufferSize, "Must error on zero buffer size") + _, err := match.Set("hello", -1) + require.ErrorIs(t, err, errInvalidBufferSize, "Must error on negative buffer size") ch, err := match.Set("hello", 2) require.NoError(t, err, "Set must not error") assert.True(t, match.IncomingWithData("hello", []byte("hello"))) From 9e57e6548bace07f8db7bd3928ec0ac189e1ba64 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 26 Jul 2024 15:57:25 +1000 Subject: [PATCH 035/138] linter: appease the linter gods --- exchanges/stream/stream_match.go | 9 +++++---- exchanges/stream/stream_match_test.go | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/exchanges/stream/stream_match.go b/exchanges/stream/stream_match.go index 94ed0109d35..6f934503da8 100644 --- a/exchanges/stream/stream_match.go +++ b/exchanges/stream/stream_match.go @@ -6,8 +6,8 @@ import ( ) var ( - errSignatureCollision = errors.New("signature collision") - errInvalidBufferSize = errors.New("buffer size must be positive") + errSignatureCollision = errors.New("signature collision") + errInvalidBufferSize = errors.New("buffer size must be positive") ) // NewMatch returns a new Match @@ -39,7 +39,8 @@ func (m *Match) IncomingWithData(signature any, data []byte) bool { return false } ch.c <- data - if ch.expected--; ch.expected == 0 { + ch.expected-- + if ch.expected == 0 { close(ch.c) delete(m.m, signature) } @@ -49,7 +50,7 @@ func (m *Match) IncomingWithData(signature any, data []byte) bool { // Set the signature response channel for incoming data func (m *Match) Set(signature any, bufSize int) (<-chan []byte, error) { if bufSize <= 0 { - return nil, errBufferShouldBeGreaterThanZero + return nil, errInvalidBufferSize } m.mu.Lock() defer m.mu.Unlock() diff --git a/exchanges/stream/stream_match_test.go b/exchanges/stream/stream_match_test.go index df51fabb04e..52acd2f95f0 100644 --- a/exchanges/stream/stream_match_test.go +++ b/exchanges/stream/stream_match_test.go @@ -17,7 +17,7 @@ func TestMatch(t *testing.T) { _, err := match.Set("hello", 0) require.ErrorIs(t, err, errInvalidBufferSize, "Must error on zero buffer size") - _, err := match.Set("hello", -1) + _, err = match.Set("hello", -1) require.ErrorIs(t, err, errInvalidBufferSize, "Must error on negative buffer size") ch, err := match.Set("hello", 2) require.NoError(t, err, "Set must not error") From 51b9f17afcf3b6fd072fd7a039ec750ee72fdd1b Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 26 Jul 2024 16:02:11 +1000 Subject: [PATCH 036/138] gk: nits --- exchanges/stream/websocket_test.go | 10 +--------- internal/testing/websocket/mock.go | 4 ++-- internal/testing/websocket/mock_test.go | 1 - 3 files changed, 3 insertions(+), 12 deletions(-) delete mode 100644 internal/testing/websocket/mock_test.go diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 0660b4e46b7..0ffb0e69951 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -365,7 +365,7 @@ func TestConnectionMessageErrors(t *testing.T) { ws.useMultiConnectionManagement = true - mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.WsWithEcho) })) + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() ws.connectionManager = []ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws" + mock.URL[len("http"):] + "/ws"}}} err = ws.Connect() @@ -590,15 +590,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { multi := NewWebsocket() set := *defaultSetup - // Values below are now not necessary as this will be set per connection - // candidate in SetupNewConnection. set.UseMultiConnectionManagement = true - set.Connector = nil - set.Subscriber = nil - set.Unsubscriber = nil - set.GenerateSubscriptions = nil - set.DefaultURL = "" - set.RunningURL = "" assert.NoError(t, multi.Setup(&set)) amazingCandidate := &ConnectionSetup{ diff --git a/internal/testing/websocket/mock.go b/internal/testing/websocket/mock.go index 5905a29daad..ce5f5dc0ab7 100644 --- a/internal/testing/websocket/mock.go +++ b/internal/testing/websocket/mock.go @@ -47,8 +47,8 @@ func WsMockUpgrader(tb testing.TB, w http.ResponseWriter, r *http.Request, wsHan } } -// WsWithEcho is a simple echo function after a read -func WsWithEcho(p []byte, c *websocket.Conn) error { +// EchoHandler is a simple echo function after a read +func EchoHandler(p []byte, c *websocket.Conn) error { err := c.WriteMessage(websocket.TextMessage, p) if err != nil { return err diff --git a/internal/testing/websocket/mock_test.go b/internal/testing/websocket/mock_test.go deleted file mode 100644 index 708bc8cb5dd..00000000000 --- a/internal/testing/websocket/mock_test.go +++ /dev/null @@ -1 +0,0 @@ -package websocket From 6376a12b18b7db8ab6d7ab68864af89e8ab06342 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 26 Jul 2024 16:07:49 +1000 Subject: [PATCH 037/138] gk: drain brain --- exchanges/stream/websocket.go | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 5d64b1df8d2..84210582f02 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -457,10 +457,8 @@ func (w *Websocket) Connect() error { // Drain residual error in the single buffered channel, this mitigates // the cycle when `Connect` is called again and the connectionMonitor // starts but there is an old error in the channel. - select { - case <-w.ReadMessageErrors: - default: - } + drain(w.ReadMessageErrors) + return multiConnectFatalError } @@ -662,11 +660,7 @@ func (w *Websocket) Shutdown() error { // Drain residual error in the single buffered channel, this mitigates // the cycle when `Connect` is called again and the connectionMonitor // starts but there is an old error in the channel. - select { - case <-w.ReadMessageErrors: - default: - } - + drain(w.ReadMessageErrors) return nil } @@ -1302,3 +1296,13 @@ func (w *Websocket) Reader(ctx context.Context, conn Connection, handler func(ct } } } + +func drain(ch <-chan error) { + for { + select { + case <-ch: + default: + return + } + } +} From f7adad2a733da0485b9046df76b744200ce1924a Mon Sep 17 00:00:00 2001 From: shazbert Date: Sun, 28 Jul 2024 18:29:24 +1000 Subject: [PATCH 038/138] started --- exchanges/gateio/gateio_test.go | 30 ++--- exchanges/gateio/gateio_websocket.go | 7 + exchanges/gateio/gateio_wrapper.go | 1 + exchanges/gateio/websocket_request.go | 148 +++++++++++++++++++++ exchanges/gateio/websocket_request_test.go | 35 +++++ exchanges/stream/stream_types.go | 4 + exchanges/stream/websocket.go | 52 ++++++++ 7 files changed, 262 insertions(+), 15 deletions(-) create mode 100644 exchanges/gateio/websocket_request.go create mode 100644 exchanges/gateio/websocket_request_test.go diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index cc785baf784..0d42f5e4492 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -144,21 +144,21 @@ func TestGetAccountInfo(t *testing.T) { if err != nil { t.Error("GetAccountInfo() error", err) } - if _, err := g.UpdateAccountInfo(context.Background(), asset.Margin); err != nil { - t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) - } - if _, err := g.UpdateAccountInfo(context.Background(), asset.CrossMargin); err != nil { - t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) - } - if _, err := g.UpdateAccountInfo(context.Background(), asset.Options); err != nil { - t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) - } - if _, err := g.UpdateAccountInfo(context.Background(), asset.Futures); err != nil { - t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) - } - if _, err := g.UpdateAccountInfo(context.Background(), asset.DeliveryFutures); err != nil { - t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) - } + // if _, err := g.UpdateAccountInfo(context.Background(), asset.Margin); err != nil { + // t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) + // } + // if _, err := g.UpdateAccountInfo(context.Background(), asset.CrossMargin); err != nil { + // t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) + // } + // if _, err := g.UpdateAccountInfo(context.Background(), asset.Options); err != nil { + // t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) + // } + // if _, err := g.UpdateAccountInfo(context.Background(), asset.Futures); err != nil { + // t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) + // } + // if _, err := g.UpdateAccountInfo(context.Background(), asset.DeliveryFutures); err != nil { + // t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) + // } } func TestWithdraw(t *testing.T) { diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 9272eda6336..8e2a3def1d4 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/buger/jsonparser" "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" @@ -94,6 +95,12 @@ func (g *Gateio) generateWsSignature(secret, event, channel string, dtime time.T // WsHandleSpotData handles spot data func (g *Gateio) WsHandleSpotData(_ context.Context, respRaw []byte) error { + if requestID, err := jsonparser.GetString(respRaw, "request_id"); err == nil && requestID != "" { + if !g.Websocket.Match.IncomingWithData(requestID, respRaw) { + return fmt.Errorf("gateio_websocket.go error - unable to match requestID %v", requestID) + } + } + var push WsResponse err := json.Unmarshal(respRaw, &push) if err != nil { diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 262a81aa10f..3e79b24ee13 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -213,6 +213,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Unsubscriber: g.SpotUnsubscribe, GenerateSubscriptions: g.GenerateDefaultSubscriptionsSpot, Connector: g.WsConnectSpot, + AllowOutbound: true, }) if err != nil { return err diff --git a/exchanges/gateio/websocket_request.go b/exchanges/gateio/websocket_request.go new file mode 100644 index 00000000000..f07602019c0 --- /dev/null +++ b/exchanges/gateio/websocket_request.go @@ -0,0 +1,148 @@ +package gateio + +import ( + "context" + "crypto/hmac" + "crypto/sha512" + "encoding/hex" + "encoding/json" + "fmt" + "strconv" + "time" + + "github.com/buger/jsonparser" + "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" + "github.com/thrasher-corp/gocryptotrader/exchanges/order" +) + +// WebsocketRequest defines a websocket request +type WebsocketRequest struct { + App string `json:"app,omitempty"` + Time int64 `json:"time,omitempty"` + ID int64 `json:"id,omitempty"` + Channel string `json:"channel"` + Event string `json:"event"` + Payload WebsocketPayload `json:"payload"` +} + +// WebsocketPayload defines an individualised websocket payload +type WebsocketPayload struct { + APIKey string `json:"api_key,omitempty"` + Signature string `json:"signature,omitempty"` + Timestamp string `json:"timestamp,omitempty"` + RequestID string `json:"req_id,omitempty"` + RequestParam []byte `json:"req_param,omitempty"` +} + +// WebsocketResponse defines a websocket response +type WebsocketResponse struct { + RequestID string `json:"req_id"` + Acknowleged bool `json:"ack"` + Header WebsocketHeader `json:"req_header"` + Data json.RawMessage `json:"data"` +} + +// WebsocketHeader defines a websocket header +type WebsocketHeader struct { + ResponseTime int64 `json:"response_time"` + Status string `json:"status"` + Channel string `json:"channel"` + Event string `json:"event"` + ClientID string `json:"client_id"` +} + +type WebsocketErrors struct { + Label string `json:"label"` + Message string `json:"message"` +} + +// GetRoute returns the route for a websocket request, this is a POC +// for the websocket wrapper. +func (g *Gateio) GetRoute(a asset.Item) (string, error) { + switch a { + case asset.Spot: + return gateioWebsocketEndpoint, nil + default: + return "", common.ErrNotYetImplemented + } +} + +// LoginResult defines a login result +type LoginResult struct { + APIKey string `json:"api_key"` + UID string `json:"uid"` +} + +// WebsocketLogin logs in to the websocket +func (g *Gateio) WebsocketLogin(ctx context.Context, a asset.Item) (*LoginResult, error) { + + route, err := g.GetRoute(a) + if err != nil { + return nil, err + } + var resp *LoginResult + err = g.SendWebsocketRequest(ctx, "spot.login", route, nil, &resp) + return resp, err +} + +// OrderPlace +// OrderCancel +// OrderCancelAllByIDList +// OrderCancelAllByPair +// OrderAmend +// OrderStatus + +func (g *Gateio) OrderPlace(ctx context.Context, batch []order.Submit) string { + return "" +} + +// SendWebsocketRequest sends a websocket request to the exchange +func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel, route string, params []byte, result any) error { + creds, err := g.GetCredentials(ctx) + if err != nil { + return err + } + + tn := time.Now() + msg := "api\n" + channel + "\n" + string(params) + "\n" + strconv.FormatInt(tn.Unix(), 10) + mac := hmac.New(sha512.New, []byte(creds.Secret)) + if _, err := mac.Write([]byte(msg)); err != nil { + return err + } + signature := hex.EncodeToString(mac.Sum(nil)) + + outbound := WebsocketRequest{ + Time: tn.Unix(), + Channel: channel, + Event: "api", + Payload: WebsocketPayload{ + RequestID: strconv.FormatInt(tn.UnixNano(), 10), + APIKey: creds.Key, + RequestParam: params, + Signature: signature, + Timestamp: strconv.FormatInt(tn.Unix(), 10), + }, + } + + var inbound WebsocketResponse + err = g.Websocket.SendRequest(ctx, route, outbound.Payload.RequestID, &outbound, &inbound) + if err != nil { + return err + } + if fail, dataType, _, _ := jsonparser.Get(inbound.Data, "errs"); dataType != jsonparser.NotExist { + var wsErr WebsocketErrors + err := json.Unmarshal(fail, &wsErr) + if err != nil { + return err + } + return fmt.Errorf("gateio websocket error: %s %s", wsErr.Label, wsErr.Message) + } + + nested, _, _, err := jsonparser.Get(inbound.Data, "result") + if err != nil { + return err + } + + return json.Unmarshal(nested, result) +} diff --git a/exchanges/gateio/websocket_request_test.go b/exchanges/gateio/websocket_request_test.go new file mode 100644 index 00000000000..dc78337e251 --- /dev/null +++ b/exchanges/gateio/websocket_request_test.go @@ -0,0 +1,35 @@ +package gateio + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" +) + +func TestWebsocketLogin(t *testing.T) { + t.Parallel() + _, err := g.WebsocketLogin(context.Background(), asset.Futures) + require.ErrorIs(t, err, common.ErrNotYetImplemented) + + require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) + for _, a := range g.GetAssetTypes(true) { + avail, err := g.GetAvailablePairs(a) + require.NoError(t, err) + if len(avail) > 1 { + avail = avail[:1] + } + require.NoError(t, g.SetPairs(avail, a, true)) + } + require.NoError(t, g.Websocket.Connect()) + g.GetBase().API.AuthenticatedSupport = true + g.GetBase().API.AuthenticatedWebsocketSupport = true + + got, err := g.WebsocketLogin(context.Background(), asset.Spot) + require.NoError(t, err) + + fmt.Println(got) +} diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 246b729b93b..a2ced830a6e 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -63,6 +63,10 @@ type ConnectionSetup struct { // received from the exchange's websocket server. This function should // handle the incoming message and pass it to the appropriate data handler. Handler func(ctx context.Context, incoming []byte) error + // AllowOutbound is a flag that determines if the connection is allowed to + // send messages to the exchange's websocket server. This will allow the + // connection to be established without subscriptions needing to be made. + AllowOutbound bool } // ConnectionWrapper contains the connection setup details to be used when diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 84210582f02..4883fe783b7 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -2,6 +2,7 @@ package stream import ( "context" + "encoding/json" "errors" "fmt" "net" @@ -1306,3 +1307,54 @@ func drain(ch <-chan error) { } } } + +var ErrRequestRouteNotFound = errors.New("request route not found") +var ErrRequestRouteNotSet = errors.New("request route not set") +var ErrSignatureNotSet = errors.New("signature not set") +var ErrRequestPayloadNotSet = errors.New("request payload not set") + +// SendRequest sends a request to a specific route and unmarhsals the response into the result +func (w *Websocket) SendRequest(_ context.Context, routeID string, signature, payload, result any) error { + if w == nil { + return fmt.Errorf("%w: Websocket", common.ErrNilPointer) + } + if routeID == "" { + return ErrRequestRouteNotSet + } + if signature == nil { + return ErrSignatureNotSet + } + if payload == nil { + return ErrRequestPayloadNotSet + } + if !w.IsConnected() { + return ErrNotConnected + } + + for x := range w.connectionManager { + if w.connectionManager[x].Setup.URL != routeID { + continue + } + + if w.connectionManager[x].Connection == nil { + return fmt.Errorf("%s: %w", w.connectionManager[x].Setup.URL, ErrNotConnected) + } + + // if w.verbose { + display, _ := json.Marshal(payload) + log.Debugf(log.WebsocketMgr, "%s websocket: sending request to %s. Data: %v", w.exchangeName, routeID, string(display)) + // } + + resp, err := w.connectionManager[x].Connection.SendMessageReturnResponse(signature, payload) + if err != nil { + return err + } + + // if w.verbose { + log.Debugf(log.WebsocketMgr, "%s websocket: received response from %s. Data: %s", w.exchangeName, routeID, resp) + // } + return json.Unmarshal(resp, result) + } + + return fmt.Errorf("%w: %s", ErrRequestRouteNotFound, routeID) +} From 99056862e7532368feedb3eaf8a444d95387911f Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 31 Jul 2024 13:24:02 +1000 Subject: [PATCH 039/138] more changes before merge match pr --- engine/engine.go | 45 +--- exchanges/exchange.go | 8 +- exchanges/gateio/gateio_test.go | 2 +- exchanges/gateio/gateio_websocket.go | 8 + exchanges/gateio/gateio_wrapper.go | 1 + exchanges/gateio/websocket_request.go | 272 ++++++++++++++++----- exchanges/gateio/websocket_request_test.go | 59 ++++- exchanges/stream/stream_types.go | 5 + exchanges/stream/websocket.go | 14 ++ 9 files changed, 316 insertions(+), 98 deletions(-) diff --git a/engine/engine.go b/engine/engine.go index 20f88cbd4cb..ef658fdb7ee 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -791,17 +791,10 @@ func (bot *Engine) LoadExchange(name string) error { localWG.Wait() if !bot.Settings.EnableExchangeHTTPRateLimiter { - gctlog.Warnf(gctlog.ExchangeSys, - "Loaded exchange %s rate limiting has been turned off.\n", - exch.GetName(), - ) + gctlog.Warnf(gctlog.ExchangeSys, "Loaded exchange %s rate limiting has been turned off.\n", exch.GetName()) err = exch.DisableRateLimiter() if err != nil { - gctlog.Errorf(gctlog.ExchangeSys, - "Loaded exchange %s rate limiting cannot be turned off: %s.\n", - exch.GetName(), - err, - ) + gctlog.Errorf(gctlog.ExchangeSys, "Loaded exchange %s rate limiting cannot be turned off: %s.\n", exch.GetName(), err) } } @@ -820,29 +813,18 @@ func (bot *Engine) LoadExchange(name string) error { return err } - base := exch.GetBase() - if base.API.AuthenticatedSupport || - base.API.AuthenticatedWebsocketSupport { - assetTypes := base.GetAssetTypes(false) - var useAsset asset.Item - for a := range assetTypes { - err = base.CurrencyPairs.IsAssetEnabled(assetTypes[a]) - if err != nil { - continue - } - useAsset = assetTypes[a] - break - } - err = exch.ValidateAPICredentials(context.TODO(), useAsset) + b := exch.GetBase() + if b.API.AuthenticatedSupport || b.API.AuthenticatedWebsocketSupport { + err = exch.ValidateAPICredentials(context.TODO(), asset.Spot) if err != nil { - gctlog.Warnf(gctlog.ExchangeSys, - "%s: Cannot validate credentials, authenticated support has been disabled, Error: %s\n", - base.Name, - err) - base.API.AuthenticatedSupport = false - base.API.AuthenticatedWebsocketSupport = false + gctlog.Warnf(gctlog.ExchangeSys, "%s: Cannot validate credentials, authenticated support has been disabled, Error: %s", b.Name, err) + b.API.AuthenticatedSupport = false + b.API.AuthenticatedWebsocketSupport = false exchCfg.API.AuthenticatedSupport = false exchCfg.API.AuthenticatedWebsocketSupport = false + if b.Websocket != nil { + b.Websocket.SetCanUseAuthenticatedEndpoints(false) + } } } @@ -855,10 +837,7 @@ func (bot *Engine) dryRunParamInteraction(param string) { } if !bot.Settings.EnableDryRun { - gctlog.Warnf(gctlog.Global, - "Command line argument '-%s' induces dry run mode."+ - " Set -dryrun=false if you wish to override this.", - param) + gctlog.Warnf(gctlog.Global, "Command line argument '-%s' induces dry run mode. Set -dryrun=false if you wish to override this.", param) bot.Settings.EnableDryRun = true } } diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 75a27905c4b..c9d80ba3455 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -982,8 +982,7 @@ func (b *Base) SupportsAsset(a asset.Item) bool { // PrintEnabledPairs prints the exchanges enabled asset pairs func (b *Base) PrintEnabledPairs() { for k, v := range b.CurrencyPairs.Pairs { - log.Infof(log.ExchangeSys, "%s Asset type %v:\n\t Enabled pairs: %v", - b.Name, strings.ToUpper(k.String()), v.Enabled) + log.Infof(log.ExchangeSys, "%s Asset type %v:\n\t Enabled pairs: %v", b.Name, strings.ToUpper(k.String()), v.Enabled) } } @@ -994,10 +993,7 @@ func (b *Base) GetBase() *Base { return b } // for validation of API credentials func (b *Base) CheckTransientError(err error) error { if _, ok := err.(net.Error); ok { - log.Warnf(log.ExchangeSys, - "%s net error captured, will not disable authentication %s", - b.Name, - err) + log.Warnf(log.ExchangeSys, "%s net error captured, will not disable authentication %s", b.Name, err) return nil } return err diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index 0d42f5e4492..043fe9471fc 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -33,7 +33,7 @@ import ( const ( apiKey = "" apiSecret = "" - canManipulateRealOrders = false + canManipulateRealOrders = true ) var g = &Gateio{} diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 8e2a3def1d4..629c78f4509 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -84,6 +84,12 @@ func (g *Gateio) WsConnectSpot(ctx context.Context, conn stream.Connection) erro return nil } +// AuthenticateSpot sends an authentication message to the websocket connection +func (g *Gateio) AuthenticateSpot(ctx context.Context, conn stream.Connection) error { + _, err := g.WebsocketLogin(ctx, conn, "spot.login") + return err +} + func (g *Gateio) generateWsSignature(secret, event, channel string, dtime time.Time) (string, error) { msg := "channel=" + channel + "&event=" + event + "&time=" + strconv.FormatInt(dtime.Unix(), 10) mac := hmac.New(sha512.New, []byte(secret)) @@ -96,9 +102,11 @@ func (g *Gateio) generateWsSignature(secret, event, channel string, dtime time.T // WsHandleSpotData handles spot data func (g *Gateio) WsHandleSpotData(_ context.Context, respRaw []byte) error { if requestID, err := jsonparser.GetString(respRaw, "request_id"); err == nil && requestID != "" { + fmt.Println("HANDLE: ", string(respRaw)) if !g.Websocket.Match.IncomingWithData(requestID, respRaw) { return fmt.Errorf("gateio_websocket.go error - unable to match requestID %v", requestID) } + return nil } var push WsResponse diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 3e79b24ee13..7537bb7d3bb 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -214,6 +214,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { GenerateSubscriptions: g.GenerateDefaultSubscriptionsSpot, Connector: g.WsConnectSpot, AllowOutbound: true, + Authenticate: g.AuthenticateSpot, }) if err != nil { return err diff --git a/exchanges/gateio/websocket_request.go b/exchanges/gateio/websocket_request.go index f07602019c0..e161f02c903 100644 --- a/exchanges/gateio/websocket_request.go +++ b/exchanges/gateio/websocket_request.go @@ -12,13 +12,14 @@ import ( "github.com/buger/jsonparser" "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" ) // WebsocketRequest defines a websocket request type WebsocketRequest struct { - App string `json:"app,omitempty"` Time int64 `json:"time,omitempty"` ID int64 `json:"id,omitempty"` Channel string `json:"channel"` @@ -28,33 +29,45 @@ type WebsocketRequest struct { // WebsocketPayload defines an individualised websocket payload type WebsocketPayload struct { - APIKey string `json:"api_key,omitempty"` - Signature string `json:"signature,omitempty"` - Timestamp string `json:"timestamp,omitempty"` - RequestID string `json:"req_id,omitempty"` - RequestParam []byte `json:"req_param,omitempty"` + RequestID string `json:"req_id,omitempty"` + // APIKey and signature are only required in the initial login request + // which is done when the connection is established. + APIKey string `json:"api_key,omitempty"` + Timestamp string `json:"timestamp,omitempty"` + Signature string `json:"signature,omitempty"` + RequestParam json.RawMessage `json:"req_param,omitempty"` } -// WebsocketResponse defines a websocket response -type WebsocketResponse struct { - RequestID string `json:"req_id"` - Acknowleged bool `json:"ack"` - Header WebsocketHeader `json:"req_header"` - Data json.RawMessage `json:"data"` -} +// // WebsocketResponse defines a websocket response +// type WebsocketResponse struct { +// RequestID string `json:"req_id"` +// APIKey string `json:"api_key"` +// Timestamp string `json:"timestamp"` +// Signature string `json:"signature"` +// TraceID string `json:"trace_id"` +// RequestHeader struct { +// TraceID string `json:"trace_id"` +// } `json:"req_header"` +// Acknowleged bool `json:"ack"` +// // Header WebsocketHeader `json:"header"` +// RequestParam json.RawMessage `json:"req_param"` +// Data json.RawMessage `json:"data"` +// } -// WebsocketHeader defines a websocket header -type WebsocketHeader struct { - ResponseTime int64 `json:"response_time"` - Status string `json:"status"` - Channel string `json:"channel"` - Event string `json:"event"` - ClientID string `json:"client_id"` -} +// // WebsocketHeader defines a websocket header +// type WebsocketHeader struct { +// ResponseTime int64 `json:"response_time"` +// Status string `json:"status"` +// Channel string `json:"channel"` +// Event string `json:"event"` +// ClientID string `json:"client_id"` +// } type WebsocketErrors struct { - Label string `json:"label"` - Message string `json:"message"` + Errors struct { + Label string `json:"label"` + Message string `json:"message"` + } `json:"errs"` } // GetRoute returns the route for a websocket request, this is a POC @@ -68,22 +81,79 @@ func (g *Gateio) GetRoute(a asset.Item) (string, error) { } } +type Header struct { + ResponseTime Time `json:"response_time"` + Status string `json:"status"` + Channel string `json:"channel"` + Event string `json:"event"` + ClientID string `json:"client_id"` + ConnectionID string `json:"conn_id"` + TraceID string `json:"trace_id"` +} + // LoginResult defines a login result type LoginResult struct { APIKey string `json:"api_key"` UID string `json:"uid"` } +type LoginResponse struct { + Header Header `json:"header"` + Data json.RawMessage `json:"data"` +} + // WebsocketLogin logs in to the websocket -func (g *Gateio) WebsocketLogin(ctx context.Context, a asset.Item) (*LoginResult, error) { +func (g *Gateio) WebsocketLogin(ctx context.Context, conn stream.Connection, channel string) (*LoginResult, error) { + creds, err := g.GetCredentials(ctx) + if err != nil { + return nil, err + } - route, err := g.GetRoute(a) + tn := time.Now() + msg := "api\n" + channel + "\n" + string([]byte(nil)) + "\n" + strconv.FormatInt(tn.Unix(), 10) + mac := hmac.New(sha512.New, []byte(creds.Secret)) + if _, err := mac.Write([]byte(msg)); err != nil { + return nil, err + } + signature := hex.EncodeToString(mac.Sum(nil)) + + outbound := WebsocketRequest{ + Time: tn.Unix(), + Channel: channel, + Event: "api", + Payload: WebsocketPayload{ + RequestID: strconv.FormatInt(tn.UnixNano(), 10), + APIKey: creds.Key, + Signature: signature, + Timestamp: strconv.FormatInt(tn.Unix(), 10), + }, + } + + resp, err := conn.SendMessageReturnResponse(outbound.Payload.RequestID, outbound) + if err != nil { + return nil, err + } + + var inbound LoginResponse + err = json.Unmarshal(resp, &inbound) if err != nil { return nil, err } - var resp *LoginResult - err = g.SendWebsocketRequest(ctx, "spot.login", route, nil, &resp) - return resp, err + + if fail, dataType, _, _ := jsonparser.Get(inbound.Data, "errs"); dataType != jsonparser.NotExist { + var wsErr WebsocketErrors + err := json.Unmarshal(fail, &wsErr.Errors) + if err != nil { + return nil, err + } + return nil, fmt.Errorf("gateio websocket error: %s %s", wsErr.Errors.Label, wsErr.Errors.Message) + } + + var result struct { + Result LoginResult `json:"result"` + } + err = json.Unmarshal(inbound.Data, &result) + return &result.Result, err } // OrderPlace @@ -93,56 +163,146 @@ func (g *Gateio) WebsocketLogin(ctx context.Context, a asset.Item) (*LoginResult // OrderAmend // OrderStatus -func (g *Gateio) OrderPlace(ctx context.Context, batch []order.Submit) string { - return "" +// WebsocketOrder defines a websocket order +type WebsocketOrder struct { + Text string `json:"text"` + CurrencyPair string `json:"currency_pair,omitempty"` + Type string `json:"type,omitempty"` + Account string `json:"account,omitempty"` + Side string `json:"side,omitempty"` + Amount string `json:"amount,omitempty"` + Price string `json:"price,omitempty"` + TimeInForce string `json:"time_in_force,omitempty"` + Iceberg string `json:"iceberg,omitempty"` + AutoBorrow bool `json:"auto_borrow,omitempty"` + AutoRepay bool `json:"auto_repay,omitempty"` + StpAct string `json:"stp_act,omitempty"` } -// SendWebsocketRequest sends a websocket request to the exchange -func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel, route string, params []byte, result any) error { - creds, err := g.GetCredentials(ctx) +type WebscocketOrderResponse struct { + ReqID string `json:"req_id"` + RequestParam any `json:"req_param"` + APIKey string `json:"api_key"` + Timestamp string `json:"timestamp"` + Signature string `json:"signature"` +} + +type WebsocketOrderParamResponse struct { + Text string `json:"text"` + CurrencyPair string `json:"currency_pair"` + Type string `json:"type"` + Account string `json:"account"` + Side string `json:"side"` + Amount string `json:"amount"` + Price string `json:"price"` +} + +var errBatchSliceEmpty = fmt.Errorf("batch cannot be empty") + +// WebsocketOrderPlace places an order via the websocket connection. You can +// send multiple orders in a single request. But only for one asset route. +// So this can only batch spot orders or futures orders, not both. +func (g *Gateio) WebsocketOrderPlace(ctx context.Context, batch []WebsocketOrder, a asset.Item) ([]WebsocketOrderParamResponse, error) { + if len(batch) == 0 { + return nil, errBatchSliceEmpty + } + + for i := range batch { + if batch[i].Text == "" { + // For some reason the API requires a text field, or it will be + // rejected in the second response. This is a workaround. + batch[i].Text = "t-" + strconv.FormatInt(time.Now().UnixNano(), 10) + } + if batch[i].CurrencyPair == "" { + return nil, currency.ErrCurrencyPairEmpty + } + if batch[i].Side == "" { + return nil, order.ErrSideIsInvalid + } + if batch[i].Amount == "" { + return nil, errInvalidAmount + } + if batch[i].Type == "limit" && batch[i].Price == "" { + return nil, errInvalidPrice + } + } + + route, err := g.GetRoute(a) if err != nil { - return err + return nil, err + } + + var resp WebscocketOrderResponse + if len(batch) == 1 { + var incoming WebsocketOrderParamResponse + resp.RequestParam = &incoming + + batchBytes, err := json.Marshal(batch[0]) + if err != nil { + return nil, err + } + + err = g.SendWebsocketRequest(ctx, "spot.order_place", route, batchBytes, &resp) + return []WebsocketOrderParamResponse{incoming}, err } + var incoming []WebsocketOrderParamResponse + resp.RequestParam = &incoming + err = g.SendWebsocketRequest(ctx, "spot.order_place", route, []byte{}, &resp) + return incoming, err +} + +// SendWebsocketRequest sends a websocket request to the exchange +func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel, route string, params json.RawMessage, result any) error { + // creds, err := g.GetCredentials(ctx) + // if err != nil { + // return err + // } + tn := time.Now() - msg := "api\n" + channel + "\n" + string(params) + "\n" + strconv.FormatInt(tn.Unix(), 10) - mac := hmac.New(sha512.New, []byte(creds.Secret)) - if _, err := mac.Write([]byte(msg)); err != nil { - return err + // msg := "api\n" + channel + "\n" + string(params) + "\n" + strconv.FormatInt(tn.Unix(), 10) + // mac := hmac.New(sha512.New, []byte(creds.Secret)) + // if _, err := mac.Write([]byte(msg)); err != nil { + // return err + // } + // signature := hex.EncodeToString(mac.Sum(nil)) + + mainPayload := WebsocketPayload{ + // This request ID associated with the payload is the match to the + // response. + RequestID: strconv.FormatInt(tn.UnixNano(), 10), + // APIKey: creds.Key, + RequestParam: params, + // Signature: signature, + Timestamp: strconv.FormatInt(tn.Unix(), 10), } - signature := hex.EncodeToString(mac.Sum(nil)) outbound := WebsocketRequest{ Time: tn.Unix(), Channel: channel, Event: "api", - Payload: WebsocketPayload{ - RequestID: strconv.FormatInt(tn.UnixNano(), 10), - APIKey: creds.Key, - RequestParam: params, - Signature: signature, - Timestamp: strconv.FormatInt(tn.Unix(), 10), - }, + Payload: mainPayload, } - var inbound WebsocketResponse - err = g.Websocket.SendRequest(ctx, route, outbound.Payload.RequestID, &outbound, &inbound) + var inbound GeneralWebsocketResponse + err := g.Websocket.SendRequest(ctx, route, outbound.Payload.RequestID, &outbound, &inbound) if err != nil { return err } - if fail, dataType, _, _ := jsonparser.Get(inbound.Data, "errs"); dataType != jsonparser.NotExist { + + if inbound.Header.Status != "200" { var wsErr WebsocketErrors - err := json.Unmarshal(fail, &wsErr) + err = json.Unmarshal(inbound.Data, &wsErr) if err != nil { return err } - return fmt.Errorf("gateio websocket error: %s %s", wsErr.Label, wsErr.Message) + return fmt.Errorf("gateio websocket error: %s %s", wsErr.Errors.Label, wsErr.Errors.Message) } - nested, _, _, err := jsonparser.Get(inbound.Data, "result") - if err != nil { - return err - } + return json.Unmarshal(inbound.Data, result) +} - return json.Unmarshal(nested, result) +type GeneralWebsocketResponse struct { + Header Header `json:"header"` + Data json.RawMessage `json:"data"` } diff --git a/exchanges/gateio/websocket_request_test.go b/exchanges/gateio/websocket_request_test.go index dc78337e251..26e73825a72 100644 --- a/exchanges/gateio/websocket_request_test.go +++ b/exchanges/gateio/websocket_request_test.go @@ -3,16 +3,23 @@ package gateio import ( "context" "fmt" + "strings" "testing" + "time" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" + "github.com/thrasher-corp/gocryptotrader/exchanges/order" + "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" ) +var loginResponse = []byte(`{"header":{"response_time":"1722227146659","status":"200","channel":"spot.login","event":"api","client_id":"14.203.57.50-0xc11df96f20"},"data":{"result":{"api_key":"4960099442600b4cfefa48ac72dacca0","uid":"2365748"}},"request_id":"1722227146427268900"}`) + func TestWebsocketLogin(t *testing.T) { t.Parallel() - _, err := g.WebsocketLogin(context.Background(), asset.Futures) + _, err := g.WebsocketLogin(context.Background(), nil, "bro.Login") require.ErrorIs(t, err, common.ErrNotYetImplemented) require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) @@ -28,8 +35,56 @@ func TestWebsocketLogin(t *testing.T) { g.GetBase().API.AuthenticatedSupport = true g.GetBase().API.AuthenticatedWebsocketSupport = true - got, err := g.WebsocketLogin(context.Background(), asset.Spot) + got, err := g.WebsocketLogin(context.Background(), nil, "bro.Login") require.NoError(t, err) fmt.Println(got) } + +var orderError = []byte(`{"header":{"response_time":"1722392009059","status":"400","channel":"spot.order_place","event":"api","client_id":"14.203.57.50-0xc0b61a0840","conn_id":"b5cd175a189984a6","trace_id":"f56a31478d7c6ce4ddaea3b337263233"},"data":{"errs":{"label":"INVALID_ARGUMENT","message":"OrderPlace request params error"}},"request_id":"1722392008842968100"}`) +var orderAcceptedResp = []byte(`{"header":{"response_time":"1722393719499","status":"200","channel":"spot.order_place","event":"api","client_id":"14.203.57.50-0xc213dab340","conn_id":"bfcbe154b8520050","trace_id":"74fbfd701d54bfe207ec79b6d2736b3a"},"data":{"result":{"req_id":"1722393719287158300","api_key":"","timestamp":"","signature":"","trace_id":"0e30c04e4e7499bccde8f83990d7168a","req_header":{"trace_id":"0e30c04e4e7499bccde8f83990d7168a"},"req_param":[{"text":"apiv4-ws","currency_pair":"BTC_USDT","type":"limit","side":"BUY","amount":"-1","price":"-1"}]}},"request_id":"1722393719287158300","ack":true}`) + +func TestWebsocketOrderPlace(t *testing.T) { + t.Parallel() + _, err := g.WebsocketOrderPlace(context.Background(), nil, 0) + require.ErrorIs(t, err, errBatchSliceEmpty) + _, err = g.WebsocketOrderPlace(context.Background(), make([]WebsocketOrder, 1), 0) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + out := WebsocketOrder{CurrencyPair: "BTC_USDT"} + _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, 0) + require.ErrorIs(t, err, order.ErrSideIsInvalid) + out.Side = strings.ToLower(order.Buy.String()) + _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, 0) + require.ErrorIs(t, err, errInvalidAmount) + out.Amount = "-1" + out.Type = "limit" + _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, 0) + require.ErrorIs(t, err, errInvalidPrice) + out.Price = "-1" + _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, 0) + require.ErrorIs(t, err, common.ErrNotYetImplemented) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) + for _, a := range g.GetAssetTypes(true) { + avail, err := g.GetAvailablePairs(a) + require.NoError(t, err) + if len(avail) > 1 { + avail = avail[:1] + } + require.NoError(t, g.SetPairs(avail, a, true)) + } + require.NoError(t, g.Websocket.Connect()) + g.GetBase().API.AuthenticatedSupport = true + g.GetBase().API.AuthenticatedWebsocketSupport = true + + out.Account = "spot" + _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, asset.Spot) + require.NoError(t, err) + + // _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out, out}, asset.Spot) + // require.NoError(t, err) + + time.Sleep(time.Second * 5) +} diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index a2ced830a6e..c02e93694ce 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -67,6 +67,11 @@ type ConnectionSetup struct { // send messages to the exchange's websocket server. This will allow the // connection to be established without subscriptions needing to be made. AllowOutbound bool + // Authenticate is a function that will be called to authenticate the + // connection to the exchange's websocket server. This function should + // handle the authentication process and return an error if the + // authentication fails. + Authenticate func(ctx context.Context, conn Connection) error } // ConnectionWrapper contains the connection setup details to be used when diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 4883fe783b7..6fe18b4331b 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -426,6 +426,20 @@ func (w *Websocket) Connect() error { w.Wg.Add(1) go w.Reader(context.TODO(), conn, w.connectionManager[i].Setup.Handler) + if w.connectionManager[i].Setup.Authenticate != nil && w.CanUseAuthenticatedEndpoints() { + fmt.Println("Authenticating") + err = w.connectionManager[i].Setup.Authenticate(context.TODO(), conn) + if err != nil { + fmt.Println("Error authenticating", err) + } else { + fmt.Println("Authenticated") + } + } + + for _, sub := range subs { + fmt.Printf("Subscribing to %+v\n", sub) + } + err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs) if err != nil { multiConnectFatalError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err) From 9bb6ef41e91654464aa0471788ac068c6d598481 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 31 Jul 2024 16:37:51 +1000 Subject: [PATCH 040/138] gateio: still building out --- exchanges/gateio/websocket_request.go | 332 +++++++++----------- exchanges/gateio/websocket_request_test.go | 97 +++++- exchanges/gateio/websocket_request_types.go | 120 +++++++ exchanges/stream/stream_types.go | 5 +- exchanges/stream/websocket.go | 72 +++-- exchanges/stream/websocket_connection.go | 8 +- 6 files changed, 421 insertions(+), 213 deletions(-) create mode 100644 exchanges/gateio/websocket_request_types.go diff --git a/exchanges/gateio/websocket_request.go b/exchanges/gateio/websocket_request.go index e161f02c903..2e0ccd9a700 100644 --- a/exchanges/gateio/websocket_request.go +++ b/exchanges/gateio/websocket_request.go @@ -8,9 +8,9 @@ import ( "encoding/json" "fmt" "strconv" + "strings" "time" - "github.com/buger/jsonparser" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" @@ -18,61 +18,12 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/stream" ) -// WebsocketRequest defines a websocket request -type WebsocketRequest struct { - Time int64 `json:"time,omitempty"` - ID int64 `json:"id,omitempty"` - Channel string `json:"channel"` - Event string `json:"event"` - Payload WebsocketPayload `json:"payload"` -} - -// WebsocketPayload defines an individualised websocket payload -type WebsocketPayload struct { - RequestID string `json:"req_id,omitempty"` - // APIKey and signature are only required in the initial login request - // which is done when the connection is established. - APIKey string `json:"api_key,omitempty"` - Timestamp string `json:"timestamp,omitempty"` - Signature string `json:"signature,omitempty"` - RequestParam json.RawMessage `json:"req_param,omitempty"` -} - -// // WebsocketResponse defines a websocket response -// type WebsocketResponse struct { -// RequestID string `json:"req_id"` -// APIKey string `json:"api_key"` -// Timestamp string `json:"timestamp"` -// Signature string `json:"signature"` -// TraceID string `json:"trace_id"` -// RequestHeader struct { -// TraceID string `json:"trace_id"` -// } `json:"req_header"` -// Acknowleged bool `json:"ack"` -// // Header WebsocketHeader `json:"header"` -// RequestParam json.RawMessage `json:"req_param"` -// Data json.RawMessage `json:"data"` -// } - -// // WebsocketHeader defines a websocket header -// type WebsocketHeader struct { -// ResponseTime int64 `json:"response_time"` -// Status string `json:"status"` -// Channel string `json:"channel"` -// Event string `json:"event"` -// ClientID string `json:"client_id"` -// } - -type WebsocketErrors struct { - Errors struct { - Label string `json:"label"` - Message string `json:"message"` - } `json:"errs"` -} +var errBatchSliceEmpty = fmt.Errorf("batch cannot be empty") +var errNoOrdersToCancel = fmt.Errorf("no orders to cancel") -// GetRoute returns the route for a websocket request, this is a POC +// GetWebsocketRoute returns the route for a websocket request, this is a POC // for the websocket wrapper. -func (g *Gateio) GetRoute(a asset.Item) (string, error) { +func (g *Gateio) GetWebsocketRoute(a asset.Item) (string, error) { switch a { case asset.Spot: return gateioWebsocketEndpoint, nil @@ -81,128 +32,58 @@ func (g *Gateio) GetRoute(a asset.Item) (string, error) { } } -type Header struct { - ResponseTime Time `json:"response_time"` - Status string `json:"status"` - Channel string `json:"channel"` - Event string `json:"event"` - ClientID string `json:"client_id"` - ConnectionID string `json:"conn_id"` - TraceID string `json:"trace_id"` -} - -// LoginResult defines a login result -type LoginResult struct { - APIKey string `json:"api_key"` - UID string `json:"uid"` -} - -type LoginResponse struct { - Header Header `json:"header"` - Data json.RawMessage `json:"data"` -} - -// WebsocketLogin logs in to the websocket -func (g *Gateio) WebsocketLogin(ctx context.Context, conn stream.Connection, channel string) (*LoginResult, error) { +// WebsocketLogin authenticates the websocket connection +func (g *Gateio) WebsocketLogin(ctx context.Context, conn stream.Connection, channel string) (*WebsocketLoginResponse, error) { creds, err := g.GetCredentials(ctx) if err != nil { return nil, err } tn := time.Now() - msg := "api\n" + channel + "\n" + string([]byte(nil)) + "\n" + strconv.FormatInt(tn.Unix(), 10) + msg := "api\n" + channel + "\n" + "\n" + strconv.FormatInt(tn.Unix(), 10) mac := hmac.New(sha512.New, []byte(creds.Secret)) if _, err := mac.Write([]byte(msg)); err != nil { return nil, err } signature := hex.EncodeToString(mac.Sum(nil)) - outbound := WebsocketRequest{ - Time: tn.Unix(), - Channel: channel, - Event: "api", - Payload: WebsocketPayload{ - RequestID: strconv.FormatInt(tn.UnixNano(), 10), - APIKey: creds.Key, - Signature: signature, - Timestamp: strconv.FormatInt(tn.Unix(), 10), - }, + payload := WebsocketPayload{ + RequestID: strconv.FormatInt(tn.UnixNano(), 10), + APIKey: creds.Key, + Signature: signature, + Timestamp: strconv.FormatInt(tn.Unix(), 10), } - resp, err := conn.SendMessageReturnResponse(outbound.Payload.RequestID, outbound) + request := WebsocketRequest{Time: tn.Unix(), Channel: channel, Event: "api", Payload: payload} + + resp, err := conn.SendMessageReturnResponse(ctx, request.Payload.RequestID, request) if err != nil { return nil, err } - var inbound LoginResponse + var inbound WebsocketAPIResponse err = json.Unmarshal(resp, &inbound) if err != nil { return nil, err } - if fail, dataType, _, _ := jsonparser.Get(inbound.Data, "errs"); dataType != jsonparser.NotExist { + if inbound.Header.Status != "200" { var wsErr WebsocketErrors - err := json.Unmarshal(fail, &wsErr.Errors) + err := json.Unmarshal(inbound.Data, &wsErr.Errors) if err != nil { return nil, err } - return nil, fmt.Errorf("gateio websocket error: %s %s", wsErr.Errors.Label, wsErr.Errors.Message) + return nil, fmt.Errorf("%s: %s", wsErr.Errors.Label, wsErr.Errors.Message) } - var result struct { - Result LoginResult `json:"result"` - } - err = json.Unmarshal(inbound.Data, &result) - return &result.Result, err + var result WebsocketLoginResponse + return &result, json.Unmarshal(inbound.Data, &result) } -// OrderPlace -// OrderCancel -// OrderCancelAllByIDList -// OrderCancelAllByPair -// OrderAmend -// OrderStatus - -// WebsocketOrder defines a websocket order -type WebsocketOrder struct { - Text string `json:"text"` - CurrencyPair string `json:"currency_pair,omitempty"` - Type string `json:"type,omitempty"` - Account string `json:"account,omitempty"` - Side string `json:"side,omitempty"` - Amount string `json:"amount,omitempty"` - Price string `json:"price,omitempty"` - TimeInForce string `json:"time_in_force,omitempty"` - Iceberg string `json:"iceberg,omitempty"` - AutoBorrow bool `json:"auto_borrow,omitempty"` - AutoRepay bool `json:"auto_repay,omitempty"` - StpAct string `json:"stp_act,omitempty"` -} - -type WebscocketOrderResponse struct { - ReqID string `json:"req_id"` - RequestParam any `json:"req_param"` - APIKey string `json:"api_key"` - Timestamp string `json:"timestamp"` - Signature string `json:"signature"` -} - -type WebsocketOrderParamResponse struct { - Text string `json:"text"` - CurrencyPair string `json:"currency_pair"` - Type string `json:"type"` - Account string `json:"account"` - Side string `json:"side"` - Amount string `json:"amount"` - Price string `json:"price"` -} - -var errBatchSliceEmpty = fmt.Errorf("batch cannot be empty") - // WebsocketOrderPlace places an order via the websocket connection. You can // send multiple orders in a single request. But only for one asset route. // So this can only batch spot orders or futures orders, not both. -func (g *Gateio) WebsocketOrderPlace(ctx context.Context, batch []WebsocketOrder, a asset.Item) ([]WebsocketOrderParamResponse, error) { +func (g *Gateio) WebsocketOrderPlace(ctx context.Context, batch []WebsocketOrder, a asset.Item) ([]WebsocketOrderResponse, error) { if len(batch) == 0 { return nil, errBatchSliceEmpty } @@ -211,7 +92,9 @@ func (g *Gateio) WebsocketOrderPlace(ctx context.Context, batch []WebsocketOrder if batch[i].Text == "" { // For some reason the API requires a text field, or it will be // rejected in the second response. This is a workaround. - batch[i].Text = "t-" + strconv.FormatInt(time.Now().UnixNano(), 10) + // +1 index for uniqueness in batch, when clock hasn't updated yet. + // TODO: Remove and use common counter. + batch[i].Text = "t-" + strconv.FormatInt(time.Now().UnixNano()+int64(i), 10) } if batch[i].CurrencyPair == "" { return nil, currency.ErrCurrencyPairEmpty @@ -227,65 +110,148 @@ func (g *Gateio) WebsocketOrderPlace(ctx context.Context, batch []WebsocketOrder } } - route, err := g.GetRoute(a) + route, err := g.GetWebsocketRoute(a) if err != nil { return nil, err } - var resp WebscocketOrderResponse if len(batch) == 1 { - var incoming WebsocketOrderParamResponse - resp.RequestParam = &incoming - - batchBytes, err := json.Marshal(batch[0]) + singleOutbound, err := json.Marshal(batch[0]) if err != nil { return nil, err } + var singleResponse WebsocketOrderResponse + err = g.SendWebsocketRequest(ctx, "spot.order_place", route, singleOutbound, &singleResponse, 2) + return []WebsocketOrderResponse{singleResponse}, err + } + + multiOutbound, err := json.Marshal(batch) + if err != nil { + return nil, err + } + var resp []WebsocketOrderResponse + err = g.SendWebsocketRequest(ctx, "spot.order_place", route, multiOutbound, &resp, 2) + return resp, err +} + +// WebsocketOrderCancel cancels an order via the websocket connection +func (g *Gateio) WebsocketOrderCancel(ctx context.Context, orderID string, pair currency.Pair, account string, a asset.Item) (*WebsocketOrderResponse, error) { + if orderID == "" { + return nil, order.ErrOrderIDNotSet + } + if pair.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + route, err := g.GetWebsocketRoute(a) + if err != nil { + return nil, err + } + + out := struct { + OrderID string `json:"order_id"` // This requires order_id tag + Pair string `json:"pair"` + Account string `json:"account,omitempty"` + }{ + OrderID: orderID, + Pair: pair.String(), + Account: account, + } + outbound, err := json.Marshal(out) + if err != nil { + return nil, err + } + var resp WebsocketOrderResponse + err = g.SendWebsocketRequest(ctx, "spot.order_cancel", route, outbound, &resp, 1) + return &resp, err +} + +type WebsocketCancellAllResponse struct { + Pair currency.Pair `json:"currency_pair"` + Label string `json:"label"` + Message string `json:"message"` + Succeeded bool `json:"succeeded"` +} + +// WebsocketOrderCancelAllByIDs cancels multiple orders via the websocket +func (g *Gateio) WebsocketOrderCancelAllByIDs(ctx context.Context, o []WebsocketOrderCancelRequest, a asset.Item) ([]WebsocketCancellAllResponse, error) { + if len(o) == 0 { + return nil, errNoOrdersToCancel + } + + for i := range o { + if o[i].OrderID == "" { + return nil, order.ErrOrderIDNotSet + } + if o[i].Pair.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + } - err = g.SendWebsocketRequest(ctx, "spot.order_place", route, batchBytes, &resp) - return []WebsocketOrderParamResponse{incoming}, err + route, err := g.GetWebsocketRoute(a) + if err != nil { + return nil, err } - var incoming []WebsocketOrderParamResponse - resp.RequestParam = &incoming - err = g.SendWebsocketRequest(ctx, "spot.order_place", route, []byte{}, &resp) - return incoming, err + outbound, err := json.Marshal(o) + if err != nil { + return nil, err + } + + var resp []WebsocketCancellAllResponse + err = g.SendWebsocketRequest(ctx, "spot.order_cancel_ids", route, outbound, &resp, 2) + return resp, err } +// OrderCancelAllByIDList +// OrderCancelAllByPair +// OrderAmend +// OrderStatus + // SendWebsocketRequest sends a websocket request to the exchange -func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel, route string, params json.RawMessage, result any) error { - // creds, err := g.GetCredentials(ctx) - // if err != nil { - // return err - // } +func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel, route string, params json.RawMessage, result any, expectedResponses int) error { + conn, err := g.Websocket.GetOutboundConnection(route) + if err != nil { + return err + } tn := time.Now() - // msg := "api\n" + channel + "\n" + string(params) + "\n" + strconv.FormatInt(tn.Unix(), 10) - // mac := hmac.New(sha512.New, []byte(creds.Secret)) - // if _, err := mac.Write([]byte(msg)); err != nil { - // return err - // } - // signature := hex.EncodeToString(mac.Sum(nil)) - mainPayload := WebsocketPayload{ // This request ID associated with the payload is the match to the // response. - RequestID: strconv.FormatInt(tn.UnixNano(), 10), - // APIKey: creds.Key, + RequestID: strconv.FormatInt(tn.UnixNano(), 10), RequestParam: params, - // Signature: signature, - Timestamp: strconv.FormatInt(tn.Unix(), 10), + Timestamp: strconv.FormatInt(tn.Unix(), 10), } - outbound := WebsocketRequest{ + request := WebsocketRequest{ Time: tn.Unix(), Channel: channel, Event: "api", Payload: mainPayload, } - var inbound GeneralWebsocketResponse - err := g.Websocket.SendRequest(ctx, route, outbound.Payload.RequestID, &outbound, &inbound) + out, _ := json.Marshal(request) + + fmt.Println("outbound:", string(out)) + + responses, err := conn.SendMessageReturnResponses(ctx, request.Payload.RequestID, request, expectedResponses, InspectPayloadForAck) + if err != nil { + return err + } + + if len(responses) == 0 { + return fmt.Errorf("no responses received") + } + + var inbound WebsocketAPIResponse + // The last response is the one we want to unmarshal, the other is just + // an ack. If the request fails on the ACK then we can unmarshal the error + // from that as the next response won't come anyway. + endResponse := responses[len(responses)-1] + + fmt.Println("response:", string(endResponse)) + + err = json.Unmarshal(endResponse, &inbound) if err != nil { return err } @@ -296,13 +262,21 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel, route string if err != nil { return err } - return fmt.Errorf("gateio websocket error: %s %s", wsErr.Errors.Label, wsErr.Errors.Message) + return fmt.Errorf("%s: %s", wsErr.Errors.Label, wsErr.Errors.Message) + } + + to := struct { + Result any `json:"result"` + }{ + Result: result, } - return json.Unmarshal(inbound.Data, result) + return json.Unmarshal(inbound.Data, &to) } -type GeneralWebsocketResponse struct { - Header Header `json:"header"` - Data json.RawMessage `json:"data"` +// InspectPayloadForAck checks the payload for an ack, it returns true if the +// payload does not contain an ack. This will force the cancellation of further +// waiting for responses. +func InspectPayloadForAck(data []byte) bool { + return !strings.Contains(string(data), "ack") } diff --git a/exchanges/gateio/websocket_request_test.go b/exchanges/gateio/websocket_request_test.go index 26e73825a72..c74a10f29d8 100644 --- a/exchanges/gateio/websocket_request_test.go +++ b/exchanges/gateio/websocket_request_test.go @@ -5,7 +5,6 @@ import ( "fmt" "strings" "testing" - "time" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" @@ -43,6 +42,10 @@ func TestWebsocketLogin(t *testing.T) { var orderError = []byte(`{"header":{"response_time":"1722392009059","status":"400","channel":"spot.order_place","event":"api","client_id":"14.203.57.50-0xc0b61a0840","conn_id":"b5cd175a189984a6","trace_id":"f56a31478d7c6ce4ddaea3b337263233"},"data":{"errs":{"label":"INVALID_ARGUMENT","message":"OrderPlace request params error"}},"request_id":"1722392008842968100"}`) var orderAcceptedResp = []byte(`{"header":{"response_time":"1722393719499","status":"200","channel":"spot.order_place","event":"api","client_id":"14.203.57.50-0xc213dab340","conn_id":"bfcbe154b8520050","trace_id":"74fbfd701d54bfe207ec79b6d2736b3a"},"data":{"result":{"req_id":"1722393719287158300","api_key":"","timestamp":"","signature":"","trace_id":"0e30c04e4e7499bccde8f83990d7168a","req_header":{"trace_id":"0e30c04e4e7499bccde8f83990d7168a"},"req_param":[{"text":"apiv4-ws","currency_pair":"BTC_USDT","type":"limit","side":"BUY","amount":"-1","price":"-1"}]}},"request_id":"1722393719287158300","ack":true}`) +var orderSecondResponseError = []byte(`{"header":{"response_time":"1722400001367","status":"400","channel":"spot.order_place","event":"api","client_id":"14.203.57.50-0xc12e5e4f20","conn_id":"4ddf3b1b45523bc3","trace_id":"8cca91e29b405e334b1901463c36afe1"},"data":{"errs":{"label":"INVALID_PARAM_VALUE","message":"label: INVALID_PARAM_VALUE, message: Your order size 0.200000 USDT is too small. The minimum is 3 USDT"}},"request_id":"1722400001142974600"}`) +var orderSecondResponseSuccess = []byte(`{"header":{"response_time":"1722400187811","status":"200","channel":"spot.order_place","event":"api","client_id":"14.203.57.50-0xc1b81a7340"},"data":{"result":{"left":"0.0003","update_time":"1722400187","amount":"0.0003","create_time":"1722400187","price":"20000","finish_as":"open","time_in_force":"gtc","currency_pair":"BTC_USDT","type":"limit","account":"spot","side":"buy","amend_text":"-","text":"t-1722400187564025900","status":"open","iceberg":"0","filled_total":"0","id":"644865690097","fill_price":"0","update_time_ms":1722400187807,"create_time_ms":1722400187807}},"request_id":"1722400187564025900"}`) +var orderBatchSuccess = []byte(`{"header":{"response_time":"1722402442822","status":"200","channel":"spot.order_place","event":"api","client_id":"14.203.57.50-0xc0e372e580"},"data":{"result":[{"account":"spot","status":"open","side":"buy","amount":"0.0003","id":"644883514616","create_time":"1722402442","update_time":"1722402442","text":"t-1722402442588484600","left":"0.0003","currency_pair":"BTC_USDT","type":"limit","finish_as":"open","price":"20000","time_in_force":"gtc","iceberg":"0","filled_total":"0","fill_price":"0","create_time_ms":1722402442819,"update_time_ms":1722402442819,"succeeded":true},{"account":"spot","status":"open","side":"buy","amount":"0.0003","id":"644883514625","create_time":"1722402442","update_time":"1722402442","text":"t-1722402442588484601","left":"0.0003","currency_pair":"BTC_USDT","type":"limit","finish_as":"open","price":"20000","time_in_force":"gtc","iceberg":"0","filled_total":"0","fill_price":"0","create_time_ms":1722402442821,"update_time_ms":1722402442821,"succeeded":true}]},"request_id":"172240244 +2588484600"}`) func TestWebsocketOrderPlace(t *testing.T) { t.Parallel() @@ -56,11 +59,11 @@ func TestWebsocketOrderPlace(t *testing.T) { out.Side = strings.ToLower(order.Buy.String()) _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, 0) require.ErrorIs(t, err, errInvalidAmount) - out.Amount = "-1" + out.Amount = "0.0003" out.Type = "limit" _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, 0) require.ErrorIs(t, err, errInvalidPrice) - out.Price = "-1" + out.Price = "20000" _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, 0) require.ErrorIs(t, err, common.ErrNotYetImplemented) @@ -79,12 +82,90 @@ func TestWebsocketOrderPlace(t *testing.T) { g.GetBase().API.AuthenticatedSupport = true g.GetBase().API.AuthenticatedWebsocketSupport = true - out.Account = "spot" - _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, asset.Spot) + // test single order + got, err := g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, asset.Spot) require.NoError(t, err) + require.NotEmpty(t, got) - // _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out, out}, asset.Spot) - // require.NoError(t, err) + // test batch orders + got, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out, out}, asset.Spot) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +var orderCancelError = []byte(`{"header":{"response_time":"1722405878406","status":"400","channel":"spot.order_cancel","event":"api","client_id":"14.203.57.50-0xc1e68ac6e0","conn_id":"0378a86ff109ca9a","trace_id":"b05be4753e751dff9175215ee020b578"},"data":{"errs":{"label":"INVALID_CURRENCY_PAIR","message":"label: INVALID_CURRENCY_PAIR, message: Invalid currency pair BTCUSD"}},"request_id":"1722405878175928500"}`) +var orderCancelSuccess = []byte(`{"header":{"response_time":"1722406252471","status":"200","channel":"spot.order_cancel","event":"api","client_id":"14.203.57.50-0xc2397b9e40"},"data":{"result":{"left":"0.0003","update_time":"1722406252","amount":"0.0003","create_time":"1722406069","price":"20000","finish_as":"cancelled","time_in_force":"gtc","currency_pair":"BTC_USDT","type":"limit","account":"spot","side":"buy","amend_text":"-","text":"t-1722406069442994700","status":"cancelled","iceberg":"0","filled_total":"0","id":"644913098758","fill_price":"0","update_time_ms":1722406252467,"create_time_ms":1722406069667}},"request_id":"1722406252236528200"}`) + +func TestWebsocketOrderCancel(t *testing.T) { + t.Parallel() + _, err := g.WebsocketOrderCancel(context.Background(), "", currency.EMPTYPAIR, "", 0) + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + _, err = g.WebsocketOrderCancel(context.Background(), "1337", currency.EMPTYPAIR, "", 0) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + btcusdt, err := currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + _, err = g.WebsocketOrderCancel(context.Background(), "1337", btcusdt, "", 0) + require.ErrorIs(t, err, common.ErrNotYetImplemented) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) + for _, a := range g.GetAssetTypes(true) { + avail, err := g.GetAvailablePairs(a) + require.NoError(t, err) + if len(avail) > 1 { + avail = avail[:1] + } + require.NoError(t, g.SetPairs(avail, a, true)) + } + require.NoError(t, g.Websocket.Connect()) + g.GetBase().API.AuthenticatedSupport = true + g.GetBase().API.AuthenticatedWebsocketSupport = true + + got, err := g.WebsocketOrderCancel(context.Background(), "644913098758", btcusdt, "", asset.Spot) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +var cancelAllfailed = []byte(`{"header":{"response_time":"1722407703038","status":"200","channel":"spot.order_cancel_ids","event":"api","client_id":"14.203.57.50-0xc36ba50dc0"},"data":{"result":[{"currency_pair":"BTC_USDT","id":"644913098758","label":"ORDER_NOT_FOUND","message":"Order not found"}]},"request_id":"1722407702811217700"}`) +var cancelAllSuccess = []byte(`{"header":{"response_time":"1722407800393","status":"200","channel":"spot.order_cancel_ids","event":"api","client_id":"14.203.57.50-0xc0ae1ed8c0"},"data":{"result":[{"currency_pair":"BTC_USDT","id":"644913101755","succeeded":true}]},"request_id":"1722407800174417400"}`) + +func TestWebsocketOrderCancelAllByIDs(t *testing.T) { + t.Parallel() + out := WebsocketOrderCancelRequest{} + _, err := g.WebsocketOrderCancelAllByIDs(context.Background(), []WebsocketOrderCancelRequest{out}, 0) + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + out.OrderID = "1337" + _, err = g.WebsocketOrderCancelAllByIDs(context.Background(), []WebsocketOrderCancelRequest{out}, 0) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + out.Pair, err = currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + _, err = g.WebsocketOrderCancelAllByIDs(context.Background(), []WebsocketOrderCancelRequest{out}, 0) + require.ErrorIs(t, err, common.ErrNotYetImplemented) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) + for _, a := range g.GetAssetTypes(true) { + avail, err := g.GetAvailablePairs(a) + require.NoError(t, err) + if len(avail) > 1 { + avail = avail[:1] + } + require.NoError(t, g.SetPairs(avail, a, true)) + } + require.NoError(t, g.Websocket.Connect()) + g.GetBase().API.AuthenticatedSupport = true + g.GetBase().API.AuthenticatedWebsocketSupport = true + + out.OrderID = "644913101755" + got, err := g.WebsocketOrderCancelAllByIDs(context.Background(), []WebsocketOrderCancelRequest{out}, asset.Spot) + require.NoError(t, err) + require.NotEmpty(t, got) - time.Sleep(time.Second * 5) + fmt.Printf("%+v\n", got) } diff --git a/exchanges/gateio/websocket_request_types.go b/exchanges/gateio/websocket_request_types.go new file mode 100644 index 00000000000..a766b37a24e --- /dev/null +++ b/exchanges/gateio/websocket_request_types.go @@ -0,0 +1,120 @@ +package gateio + +import ( + "encoding/json" + + "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/types" +) + +// WebsocketAPIResponse defines a general websocket response for api calls +type WebsocketAPIResponse struct { + Header Header `json:"header"` + Data json.RawMessage `json:"data"` +} + +// Header defines a websocket header +type Header struct { + ResponseTime Time `json:"response_time"` + Status string `json:"status"` + Channel string `json:"channel"` + Event string `json:"event"` + ClientID string `json:"client_id"` + ConnectionID string `json:"conn_id"` + TraceID string `json:"trace_id"` +} + +// WebsocketRequest defines a websocket request +type WebsocketRequest struct { + Time int64 `json:"time,omitempty"` + ID int64 `json:"id,omitempty"` + Channel string `json:"channel"` + Event string `json:"event"` + Payload WebsocketPayload `json:"payload"` +} + +// WebsocketPayload defines an individualised websocket payload +type WebsocketPayload struct { + RequestID string `json:"req_id,omitempty"` + // APIKey and signature are only required in the initial login request + // which is done when the connection is established. + APIKey string `json:"api_key,omitempty"` + Timestamp string `json:"timestamp,omitempty"` + Signature string `json:"signature,omitempty"` + RequestParam json.RawMessage `json:"req_param,omitempty"` +} + +// WebsocketErrors defines a websocket error +type WebsocketErrors struct { + Errors struct { + Label string `json:"label"` + Message string `json:"message"` + } `json:"errs"` +} + +// WebsocketLoginResponse defines a websocket login response when authenticating +// the connection. +type WebsocketLoginResponse struct { + Result struct { + APIKey string `json:"api_key"` + UID string `json:"uid"` + } `json:"result"` +} + +// WebsocketOrder defines a websocket order +type WebsocketOrder struct { + Text string `json:"text"` + CurrencyPair string `json:"currency_pair,omitempty"` + Type string `json:"type,omitempty"` + Account string `json:"account,omitempty"` + Side string `json:"side,omitempty"` + Amount string `json:"amount,omitempty"` + Price string `json:"price,omitempty"` + TimeInForce string `json:"time_in_force,omitempty"` + Iceberg string `json:"iceberg,omitempty"` + AutoBorrow bool `json:"auto_borrow,omitempty"` + AutoRepay bool `json:"auto_repay,omitempty"` + StpAct string `json:"stp_act,omitempty"` +} + +// WebsocketOrderResponse defines a websocket order response +type WebsocketOrderResponse struct { + Left types.Number `json:"left"` + UpdateTime Time `json:"update_time"` + Amount types.Number `json:"amount"` + CreateTime Time `json:"create_time"` + Price types.Number `json:"price"` + FinishAs string `json:"finish_as"` + TimeInForce string `json:"time_in_force"` + CurrencyPair currency.Pair `json:"currency_pair"` + Type string `json:"type"` + Account string `json:"account"` + Side string `json:"side"` + AmendText string `json:"amend_text"` + Text string `json:"text"` + Status string `json:"status"` + Iceberg types.Number `json:"iceberg"` + FilledTotal types.Number `json:"filled_total"` + ID string `json:"id"` + FillPrice types.Number `json:"fill_price"` + UpdateTimeMs Time `json:"update_time_ms"` + CreateTimeMs Time `json:"create_time_ms"` + Fee types.Number `json:"fee"` + FeeCurrency currency.Code `json:"fee_currency"` + PointFee types.Number `json:"point_fee"` + GTFee types.Number `json:"gt_fee"` + GTMakerFee types.Number `json:"gt_maker_fee"` + GTTakerFee types.Number `json:"gt_taker_fee"` + GTDiscount bool `json:"gt_discount"` + RebatedFee types.Number `json:"rebated_fee"` + RebatedFeeCurrency currency.Code `json:"rebated_fee_currency"` + STPID int `json:"stp_id"` + STPAct string `json:"stp_act"` +} + +// WebsocketOrderCancelRequest defines a websocket order cancel request +type WebsocketOrderCancelRequest struct { + OrderID string `json:"id"` // This require id tag not order_id + Pair currency.Pair `json:"currency_pair"` + Account string `json:"account,omitempty"` +} diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index dcd6e0064a5..13ab4f1d98d 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -21,7 +21,7 @@ type Connection interface { SetupPingHandler(PingHandler) GenerateMessageID(highPrecision bool) int64 SendMessageReturnResponse(ctx context.Context, signature any, request any) ([]byte, error) - SendMessageReturnResponses(ctx context.Context, signature any, request any, expected int) ([][]byte, error) + SendMessageReturnResponses(ctx context.Context, signature any, request any, expected int, isFinalMessage ...Inspector) ([][]byte, error) SendRawMessage(messageType int, message []byte) error SetURL(string) SetProxy(string) @@ -29,6 +29,9 @@ type Connection interface { Shutdown() error } +// Inspector is a hook that allows for custom message inspection +type Inspector func([]byte) bool + // Response defines generalised data from the stream connection type Response struct { Type int diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 35343b22d0c..8c9fdf0f7c7 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -1329,48 +1329,72 @@ var ErrRequestRouteNotSet = errors.New("request route not set") var ErrSignatureNotSet = errors.New("signature not set") var ErrRequestPayloadNotSet = errors.New("request payload not set") -// SendRequest sends a request to a specific route and unmarhsals the response into the result -func (w *Websocket) SendRequest(_ context.Context, routeID string, signature, payload, result any) error { +// SendRequest sends a request to a specific route and unmarhsals the response +// into the result. NOTE: Only for multi connection management. +func (w *Websocket) SendRequest(ctx context.Context, routeID string, signature, payload, result any) error { if w == nil { return fmt.Errorf("%w: Websocket", common.ErrNilPointer) } - if routeID == "" { - return ErrRequestRouteNotSet - } + if signature == nil { return ErrSignatureNotSet } if payload == nil { return ErrRequestPayloadNotSet } + + outbound, err := w.GetOutboundConnection(routeID) + if err != nil { + return err + } + + // if w.verbose { + display, _ := json.Marshal(payload) + log.Debugf(log.WebsocketMgr, "%s websocket: sending request to %s. Data: %v", w.exchangeName, routeID, string(display)) + // } + + resp, err := outbound.SendMessageReturnResponse(ctx, signature, payload) + if err != nil { + return err + } + + // if w.verbose { + log.Debugf(log.WebsocketMgr, "%s websocket: received response from %s. Data: %s", w.exchangeName, routeID, resp) + // } + return json.Unmarshal(resp, result) +} + +var errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") + +// GetOutboundConnection returns a connection specifically for outbound requests +// for multi connection management. TODO: Upgrade routeID so that if there is +// a URL change it can be handled. +func (w *Websocket) GetOutboundConnection(routeID string) (Connection, error) { + if w == nil { + return nil, fmt.Errorf("%w: Websocket", common.ErrNilPointer) + } + if !w.IsConnected() { - return ErrNotConnected + return nil, ErrNotConnected + } + + if !w.useMultiConnectionManagement { + return nil, fmt.Errorf("%s: multi connection management not enabled %w please use exported Conn and AuthConn fields", w.exchangeName, errCannotObtainOutboundConnection) + } + + if routeID == "" { + return nil, ErrRequestRouteNotSet } for x := range w.connectionManager { if w.connectionManager[x].Setup.URL != routeID { continue } - if w.connectionManager[x].Connection == nil { - return fmt.Errorf("%s: %w", w.connectionManager[x].Setup.URL, ErrNotConnected) - } - - // if w.verbose { - display, _ := json.Marshal(payload) - log.Debugf(log.WebsocketMgr, "%s websocket: sending request to %s. Data: %v", w.exchangeName, routeID, string(display)) - // } - - resp, err := w.connectionManager[x].Connection.SendMessageReturnResponse(signature, payload) - if err != nil { - return err + return nil, fmt.Errorf("%s: %w", w.connectionManager[x].Setup.URL, ErrNotConnected) } - - // if w.verbose { - log.Debugf(log.WebsocketMgr, "%s websocket: received response from %s. Data: %s", w.exchangeName, routeID, resp) - // } - return json.Unmarshal(resp, result) + return w.connectionManager[x].Connection, nil } - return fmt.Errorf("%w: %s", ErrRequestRouteNotFound, routeID) + return nil, fmt.Errorf("%w: %s", ErrRequestRouteNotFound, routeID) } diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 2859ce0fbd6..3252c7600c2 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -288,7 +288,7 @@ func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, sig // SendMessageReturnResponses will send a WS message to the connection and wait for N responses // An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked -func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, signature, request any, expected int) ([][]byte, error) { +func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, signature, request any, expected int, isFinalMessage ...Inspector) ([][]byte, error) { outbound, err := json.Marshal(request) if err != nil { return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err) @@ -319,6 +319,12 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, si w.Match.RemoveSignature(signature) err = ctx.Err() } + // Checks recently received message to determine if this is in fact the + // final message in a sequence of messages. + if len(isFinalMessage) == 1 && isFinalMessage[0](resps[len(resps)-1]) { + w.Match.RemoveSignature(signature) + break + } } timeout.Stop() From 3f3da46ba6d2bf07a572e6f6542be9660842695f Mon Sep 17 00:00:00 2001 From: shazbert Date: Wed, 31 Jul 2024 20:18:45 +1000 Subject: [PATCH 041/138] gateio: finish spot --- exchanges/gateio/gateio_test.go | 30 ++-- exchanges/gateio/websocket_request.go | 154 +++++++++++++------- exchanges/gateio/websocket_request_test.go | 120 ++++++++++++++- exchanges/gateio/websocket_request_types.go | 25 ++++ 4 files changed, 264 insertions(+), 65 deletions(-) diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index 043fe9471fc..97618b5db1a 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -144,21 +144,21 @@ func TestGetAccountInfo(t *testing.T) { if err != nil { t.Error("GetAccountInfo() error", err) } - // if _, err := g.UpdateAccountInfo(context.Background(), asset.Margin); err != nil { - // t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) - // } - // if _, err := g.UpdateAccountInfo(context.Background(), asset.CrossMargin); err != nil { - // t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) - // } - // if _, err := g.UpdateAccountInfo(context.Background(), asset.Options); err != nil { - // t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) - // } - // if _, err := g.UpdateAccountInfo(context.Background(), asset.Futures); err != nil { - // t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) - // } - // if _, err := g.UpdateAccountInfo(context.Background(), asset.DeliveryFutures); err != nil { - // t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) - // } + if _, err := g.UpdateAccountInfo(context.Background(), asset.Margin); err != nil { + t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) + } + if _, err := g.UpdateAccountInfo(context.Background(), asset.CrossMargin); err != nil { + t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) + } + if _, err := g.UpdateAccountInfo(context.Background(), asset.Options); err != nil { + t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) + } + if _, err := g.UpdateAccountInfo(context.Background(), asset.Futures); err != nil { + t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) + } + if _, err := g.UpdateAccountInfo(context.Background(), asset.DeliveryFutures); err != nil { + t.Errorf("%s UpdateAccountInfo() error %v", g.Name, err) + } } func TestWithdraw(t *testing.T) { diff --git a/exchanges/gateio/websocket_request.go b/exchanges/gateio/websocket_request.go index 2e0ccd9a700..a049f006e66 100644 --- a/exchanges/gateio/websocket_request.go +++ b/exchanges/gateio/websocket_request.go @@ -6,6 +6,7 @@ import ( "crypto/sha512" "encoding/hex" "encoding/json" + "errors" "fmt" "strconv" "strings" @@ -18,8 +19,11 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/stream" ) -var errBatchSliceEmpty = fmt.Errorf("batch cannot be empty") -var errNoOrdersToCancel = fmt.Errorf("no orders to cancel") +var ( + errBatchSliceEmpty = errors.New("batch cannot be empty") + errNoOrdersToCancel = errors.New("no orders to cancel") + errEdgeCaseIssue = errors.New("edge case issue") +) // GetWebsocketRoute returns the route for a websocket request, this is a POC // for the websocket wrapper. @@ -116,21 +120,13 @@ func (g *Gateio) WebsocketOrderPlace(ctx context.Context, batch []WebsocketOrder } if len(batch) == 1 { - singleOutbound, err := json.Marshal(batch[0]) - if err != nil { - return nil, err - } var singleResponse WebsocketOrderResponse - err = g.SendWebsocketRequest(ctx, "spot.order_place", route, singleOutbound, &singleResponse, 2) + err = g.SendWebsocketRequest(ctx, "spot.order_place", route, batch[0], &singleResponse, 2) return []WebsocketOrderResponse{singleResponse}, err } - multiOutbound, err := json.Marshal(batch) - if err != nil { - return nil, err - } var resp []WebsocketOrderResponse - err = g.SendWebsocketRequest(ctx, "spot.order_place", route, multiOutbound, &resp, 2) + err = g.SendWebsocketRequest(ctx, "spot.order_place", route, batch, &resp, 2) return resp, err } @@ -147,7 +143,7 @@ func (g *Gateio) WebsocketOrderCancel(ctx context.Context, orderID string, pair return nil, err } - out := struct { + params := &struct { OrderID string `json:"order_id"` // This requires order_id tag Pair string `json:"pair"` Account string `json:"account,omitempty"` @@ -156,22 +152,12 @@ func (g *Gateio) WebsocketOrderCancel(ctx context.Context, orderID string, pair Pair: pair.String(), Account: account, } - outbound, err := json.Marshal(out) - if err != nil { - return nil, err - } + var resp WebsocketOrderResponse - err = g.SendWebsocketRequest(ctx, "spot.order_cancel", route, outbound, &resp, 1) + err = g.SendWebsocketRequest(ctx, "spot.order_cancel", route, params, &resp, 1) return &resp, err } -type WebsocketCancellAllResponse struct { - Pair currency.Pair `json:"currency_pair"` - Label string `json:"label"` - Message string `json:"message"` - Succeeded bool `json:"succeeded"` -} - // WebsocketOrderCancelAllByIDs cancels multiple orders via the websocket func (g *Gateio) WebsocketOrderCancelAllByIDs(ctx context.Context, o []WebsocketOrderCancelRequest, a asset.Item) ([]WebsocketCancellAllResponse, error) { if len(o) == 0 { @@ -192,48 +178,120 @@ func (g *Gateio) WebsocketOrderCancelAllByIDs(ctx context.Context, o []Websocket return nil, err } - outbound, err := json.Marshal(o) + var resp []WebsocketCancellAllResponse + err = g.SendWebsocketRequest(ctx, "spot.order_cancel_ids", route, o, &resp, 2) + return resp, err +} + +// WebsocketOrderCancelAllByPair cancels all orders for a specific pair +func (g *Gateio) WebsocketOrderCancelAllByPair(ctx context.Context, pair currency.Pair, side order.Side, account string, a asset.Item) ([]WebsocketOrderResponse, error) { + if !pair.IsEmpty() && side == order.UnknownSide { + return nil, fmt.Errorf("%w: side cannot be unknown when pair is set as this will purge *ALL* open orders", errEdgeCaseIssue) + } + + sideStr := "" + if side != order.UnknownSide { + sideStr = side.Lower() + } + + route, err := g.GetWebsocketRoute(a) if err != nil { return nil, err } - var resp []WebsocketCancellAllResponse - err = g.SendWebsocketRequest(ctx, "spot.order_cancel_ids", route, outbound, &resp, 2) + params := &WebsocketCancelParam{ + Pair: pair, + Side: sideStr, + Account: account, + } + + var resp []WebsocketOrderResponse + err = g.SendWebsocketRequest(ctx, "spot.order_cancel_cp", route, params, &resp, 1) return resp, err } -// OrderCancelAllByIDList -// OrderCancelAllByPair -// OrderAmend -// OrderStatus +// WebsocketOrderAmend amends an order via the websocket connection +func (g *Gateio) WebsocketOrderAmend(ctx context.Context, amend *WebsocketAmendOrder, a asset.Item) (*WebsocketOrderResponse, error) { + if amend == nil { + return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, amend) + } + + if amend.OrderID == "" { + return nil, order.ErrOrderIDNotSet + } + + if amend.Pair.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + + if amend.Amount == "" && amend.Price == "" { + return nil, fmt.Errorf("%w: amount or price must be set", errInvalidAmount) + } + + route, err := g.GetWebsocketRoute(a) + if err != nil { + return nil, err + } + + var resp WebsocketOrderResponse + err = g.SendWebsocketRequest(ctx, "spot.order_amend", route, amend, &resp, 1) + return &resp, err +} + +// WebsocketGetOrderStatus gets the status of an order via the websocket connection +func (g *Gateio) WebsocketGetOrderStatus(ctx context.Context, orderID string, pair currency.Pair, account string, a asset.Item) (*WebsocketOrderResponse, error) { + if orderID == "" { + return nil, order.ErrOrderIDNotSet + } + if pair.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + route, err := g.GetWebsocketRoute(a) + if err != nil { + return nil, err + } + + params := &struct { + OrderID string `json:"order_id"` // This requires order_id tag + Pair string `json:"pair"` + Account string `json:"account,omitempty"` + }{ + OrderID: orderID, + Pair: pair.String(), + Account: account, + } + + var resp WebsocketOrderResponse + err = g.SendWebsocketRequest(ctx, "spot.order_status", route, params, &resp, 1) + return &resp, err +} // SendWebsocketRequest sends a websocket request to the exchange -func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel, route string, params json.RawMessage, result any, expectedResponses int) error { - conn, err := g.Websocket.GetOutboundConnection(route) +func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel, route string, params, result any, expectedResponses int) error { + paramPayload, err := json.Marshal(params) if err != nil { return err } - tn := time.Now() - mainPayload := WebsocketPayload{ - // This request ID associated with the payload is the match to the - // response. - RequestID: strconv.FormatInt(tn.UnixNano(), 10), - RequestParam: params, - Timestamp: strconv.FormatInt(tn.Unix(), 10), + conn, err := g.Websocket.GetOutboundConnection(route) + if err != nil { + return err } - request := WebsocketRequest{ + tn := time.Now() + request := &WebsocketRequest{ Time: tn.Unix(), Channel: channel, Event: "api", - Payload: mainPayload, + Payload: WebsocketPayload{ + // This request ID associated with the payload is the match to the + // response. + RequestID: strconv.FormatInt(tn.UnixNano(), 10), + RequestParam: paramPayload, + Timestamp: strconv.FormatInt(tn.Unix(), 10), + }, } - out, _ := json.Marshal(request) - - fmt.Println("outbound:", string(out)) - responses, err := conn.SendMessageReturnResponses(ctx, request.Payload.RequestID, request, expectedResponses, InspectPayloadForAck) if err != nil { return err @@ -249,8 +307,6 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel, route string // from that as the next response won't come anyway. endResponse := responses[len(responses)-1] - fmt.Println("response:", string(endResponse)) - err = json.Unmarshal(endResponse, &inbound) if err != nil { return err diff --git a/exchanges/gateio/websocket_request_test.go b/exchanges/gateio/websocket_request_test.go index c74a10f29d8..a86ddcd7b63 100644 --- a/exchanges/gateio/websocket_request_test.go +++ b/exchanges/gateio/websocket_request_test.go @@ -166,6 +166,124 @@ func TestWebsocketOrderCancelAllByIDs(t *testing.T) { got, err := g.WebsocketOrderCancelAllByIDs(context.Background(), []WebsocketOrderCancelRequest{out}, asset.Spot) require.NoError(t, err) require.NotEmpty(t, got) +} + +var cancelAllByPairSuccess = []byte(`{"header":{"response_time":"1722415590482","status":"200","channel":"spot.order_cancel_cp","event":"api","client_id":"58.169.146.133-0xc028f00b00"},"data":{"result":[{"left":"0.0003","update_time":"1722415590","amount":"0.0003","create_time":"1722406069","price":"20000","finish_as":"cancelled","time_in_force":"gtc","currency_pair":"BTC_USDT","type":"limit","account":"spot","side":"buy","amend_text":"-","text":"t-1722406069759058701","status":"cancelled","iceberg":"0","filled_total":"0","id":"644913101780","fill_price":"0","update_time_ms":1722415590471,"create_time_ms":1722406069992}]},"request_id":"1722415590230464500"}`) + +func TestWebsocketOrderCancelAllByPair(t *testing.T) { + t.Parallel() + pair, err := currency.NewPairFromString("LTC_USDT") + require.NoError(t, err) + + _, err = g.WebsocketOrderCancelAllByPair(context.Background(), pair, 0, "", 0) + require.ErrorIs(t, err, errEdgeCaseIssue) - fmt.Printf("%+v\n", got) + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) + for _, a := range g.GetAssetTypes(true) { + avail, err := g.GetAvailablePairs(a) + require.NoError(t, err) + if len(avail) > 1 { + avail = avail[:1] + } + require.NoError(t, g.SetPairs(avail, a, true)) + } + require.NoError(t, g.Websocket.Connect()) + g.GetBase().API.AuthenticatedSupport = true + g.GetBase().API.AuthenticatedWebsocketSupport = true + + got, err := g.WebsocketOrderCancelAllByPair(context.Background(), currency.EMPTYPAIR, order.Buy, "", asset.Spot) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +var amendOrderError = []byte(`{"header":{"response_time":"1722420643127","status":"404","channel":"spot.order_amend","event":"api","client_id":"58.169.146.133-0xc1e615e6e0","conn_id":"71eb27ad8803a9bd","trace_id":"4d80b11b184b49bd540abd039f42a84d"},"data":{"errs":{"label":"ORDER_NOT_FOUND","message":"label: ORDER_NOT_FOUND, message: Order not found"}},"request_id":"1722420642896203600"}`) +var ammendOrderSuccess = []byte(`"header":{"response_time":"1722420772699","status":"200","channel":"spot.order_amend","event":"api","client_id":"58.169.146.133-0xc08c7c2f20"},"data":{"result":{"left":"0.0004","update_time":"1722420772","amount":"0.0004","create_time":"1722420733","price":"20000","finish_as":"open","time_in_force":"gtc","currency_pair":"BTC_USDT","type":"limit","account":"spot","side":"buy","amend_text":"-","text":"t-1722420733733908700","status":"open","iceberg":"0","filled_total":"0","id":"645029162673","fill_price":"0","update_time_ms":1722420772698,"create_time_ms":1722420733966}},"request_id":"1722420772476042600"}`) + +func TestWebsocketOrderAmend(t *testing.T) { + t.Parallel() + + _, err := g.WebsocketOrderAmend(context.Background(), nil, 0) + require.ErrorIs(t, err, common.ErrNilPointer) + + amend := &WebsocketAmendOrder{} + _, err = g.WebsocketOrderAmend(context.Background(), amend, 0) + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + + amend.OrderID = "1337" + _, err = g.WebsocketOrderAmend(context.Background(), amend, 0) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + amend.Pair, err = currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + _, err = g.WebsocketOrderAmend(context.Background(), amend, 0) + require.ErrorIs(t, err, errInvalidAmount) + + amend.Amount = "0.0004" + + _, err = g.WebsocketOrderAmend(context.Background(), amend, 0) + require.ErrorIs(t, err, common.ErrNotYetImplemented) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) + for _, a := range g.GetAssetTypes(true) { + avail, err := g.GetAvailablePairs(a) + require.NoError(t, err) + if len(avail) > 1 { + avail = avail[:1] + } + require.NoError(t, g.SetPairs(avail, a, true)) + } + require.NoError(t, g.Websocket.Connect()) + g.GetBase().API.AuthenticatedSupport = true + g.GetBase().API.AuthenticatedWebsocketSupport = true + + amend.OrderID = "645029162673" + got, err := g.WebsocketOrderAmend(context.Background(), amend, asset.Spot) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +var getOrderStatusError = []byte(`{"header":{"response_time":"1722417357718","status":"404","channel":"spot.order_status","event":"api","client_id":"58.169.146.133-0xc0e6013600","conn_id":"8ae56147f8a55b08","trace_id":"127ac043f3a762ae88b746122aba5e3b"},"data":{"errs":{"label":"ORDER_NOT_FOUND","message":"label: ORDER_NOT_FOUND, message: Order with ID 644999648436 not found"}},"request_id":"1722417357478800700"}`) +var getOrderStatusSuccess = []byte(`{"header":{"response_time":"1722417915985","status":"200","channel":"spot.order_status","event":"api","client_id":"58.169.146.133-0xc06e7ff1e0"},"data":{"result":{"left":"0.0003","update_time":"1722417700","amount":"0.0003","create_time":"1722416858","price":"20000","finish_as":"cancelled","time_in_force":"gtc","currency_pair":"BTC_USDT","type":"limit","account":"spot","side":"buy","amend_text":"-","text":"t-1722416858697102100","status":"cancelled","iceberg":"0","filled_total":"0","id":"644999650452","fill_price":"0","update_time_ms":1722417700653,"create_time_ms":1722416858942}},"request_id":"1722417915744467800"}`) + +func TestWebsocketGetOrderStatus(t *testing.T) { + t.Parallel() + + _, err := g.WebsocketGetOrderStatus(context.Background(), "", currency.EMPTYPAIR, "", 0) + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + + _, err = g.WebsocketGetOrderStatus(context.Background(), "1337", currency.EMPTYPAIR, "", 0) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + pair, err := currency.NewPairFromString("LTC_USDT") + require.NoError(t, err) + + _, err = g.WebsocketGetOrderStatus(context.Background(), "1337", pair, "", 0) + require.ErrorIs(t, err, common.ErrNotYetImplemented) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) + for _, a := range g.GetAssetTypes(true) { + avail, err := g.GetAvailablePairs(a) + require.NoError(t, err) + if len(avail) > 1 { + avail = avail[:1] + } + require.NoError(t, g.SetPairs(avail, a, true)) + } + require.NoError(t, g.Websocket.Connect()) + g.GetBase().API.AuthenticatedSupport = true + g.GetBase().API.AuthenticatedWebsocketSupport = true + + pair, err = currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + got, err := g.WebsocketGetOrderStatus(context.Background(), "644999650452", pair, "", asset.Spot) + require.NoError(t, err) + require.NotEmpty(t, got) } diff --git a/exchanges/gateio/websocket_request_types.go b/exchanges/gateio/websocket_request_types.go index a766b37a24e..d662a3aa912 100644 --- a/exchanges/gateio/websocket_request_types.go +++ b/exchanges/gateio/websocket_request_types.go @@ -118,3 +118,28 @@ type WebsocketOrderCancelRequest struct { Pair currency.Pair `json:"currency_pair"` Account string `json:"account,omitempty"` } + +// WebsocketOrderCancelResponse defines a websocket order cancel response +type WebsocketCancellAllResponse struct { + Pair currency.Pair `json:"currency_pair"` + Label string `json:"label"` + Message string `json:"message"` + Succeeded bool `json:"succeeded"` +} + +// WebsocketCancelParam is a struct to hold the parameters for cancelling orders +type WebsocketCancelParam struct { + Pair currency.Pair `json:"pair"` + Side string `json:"side"` + Account string `json:"account,omitempty"` +} + +// WebsocketAmendOrder defines a websocket amend order +type WebsocketAmendOrder struct { + OrderID string `json:"order_id"` + Pair currency.Pair `json:"currency_pair"` + Account string `json:"account,omitempty"` + AmendText string `json:"amend_text,omitempty"` + Price string `json:"price,omitempty"` + Amount string `json:"amount,omitempty"` +} From 6b9c4ad92ca969a3002c441803722e6b606f8a65 Mon Sep 17 00:00:00 2001 From: shazbert Date: Thu, 1 Aug 2024 06:14:24 +1000 Subject: [PATCH 042/138] fix up tests in gateio --- exchanges/gateio/gateio_test.go | 2 +- exchanges/gateio/websocket_request.go | 9 ++ exchanges/gateio/websocket_request_test.go | 175 ++++++++------------- 3 files changed, 74 insertions(+), 112 deletions(-) diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index 97618b5db1a..cc785baf784 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -33,7 +33,7 @@ import ( const ( apiKey = "" apiSecret = "" - canManipulateRealOrders = true + canManipulateRealOrders = false ) var g = &Gateio{} diff --git a/exchanges/gateio/websocket_request.go b/exchanges/gateio/websocket_request.go index a049f006e66..734a2855564 100644 --- a/exchanges/gateio/websocket_request.go +++ b/exchanges/gateio/websocket_request.go @@ -23,6 +23,7 @@ var ( errBatchSliceEmpty = errors.New("batch cannot be empty") errNoOrdersToCancel = errors.New("no orders to cancel") errEdgeCaseIssue = errors.New("edge case issue") + errChannelEmpty = errors.New("channel cannot be empty") ) // GetWebsocketRoute returns the route for a websocket request, this is a POC @@ -38,6 +39,14 @@ func (g *Gateio) GetWebsocketRoute(a asset.Item) (string, error) { // WebsocketLogin authenticates the websocket connection func (g *Gateio) WebsocketLogin(ctx context.Context, conn stream.Connection, channel string) (*WebsocketLoginResponse, error) { + if conn == nil { + return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, conn) + } + + if channel == "" { + return nil, errChannelEmpty + } + creds, err := g.GetCredentials(ctx) if err != nil { return nil, err diff --git a/exchanges/gateio/websocket_request_test.go b/exchanges/gateio/websocket_request_test.go index a86ddcd7b63..060d293099c 100644 --- a/exchanges/gateio/websocket_request_test.go +++ b/exchanges/gateio/websocket_request_test.go @@ -2,50 +2,43 @@ package gateio import ( "context" - "fmt" "strings" "testing" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" ) -var loginResponse = []byte(`{"header":{"response_time":"1722227146659","status":"200","channel":"spot.login","event":"api","client_id":"14.203.57.50-0xc11df96f20"},"data":{"result":{"api_key":"4960099442600b4cfefa48ac72dacca0","uid":"2365748"}},"request_id":"1722227146427268900"}`) - func TestWebsocketLogin(t *testing.T) { t.Parallel() - _, err := g.WebsocketLogin(context.Background(), nil, "bro.Login") - require.ErrorIs(t, err, common.ErrNotYetImplemented) + _, err := g.WebsocketLogin(context.Background(), nil, "") + require.ErrorIs(t, err, common.ErrNilPointer) - require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) - for _, a := range g.GetAssetTypes(true) { - avail, err := g.GetAvailablePairs(a) - require.NoError(t, err) - if len(avail) > 1 { - avail = avail[:1] - } - require.NoError(t, g.SetPairs(avail, a, true)) - } - require.NoError(t, g.Websocket.Connect()) - g.GetBase().API.AuthenticatedSupport = true - g.GetBase().API.AuthenticatedWebsocketSupport = true + _, err = g.WebsocketLogin(context.Background(), &stream.WebsocketConnection{}, "") + require.ErrorIs(t, err, errChannelEmpty) - got, err := g.WebsocketLogin(context.Background(), nil, "bro.Login") + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + route, err := g.GetWebsocketRoute(asset.Spot) require.NoError(t, err) - fmt.Println(got) -} + demonstrationConn, err := g.Websocket.GetOutboundConnection(route) + require.NoError(t, err) -var orderError = []byte(`{"header":{"response_time":"1722392009059","status":"400","channel":"spot.order_place","event":"api","client_id":"14.203.57.50-0xc0b61a0840","conn_id":"b5cd175a189984a6","trace_id":"f56a31478d7c6ce4ddaea3b337263233"},"data":{"errs":{"label":"INVALID_ARGUMENT","message":"OrderPlace request params error"}},"request_id":"1722392008842968100"}`) -var orderAcceptedResp = []byte(`{"header":{"response_time":"1722393719499","status":"200","channel":"spot.order_place","event":"api","client_id":"14.203.57.50-0xc213dab340","conn_id":"bfcbe154b8520050","trace_id":"74fbfd701d54bfe207ec79b6d2736b3a"},"data":{"result":{"req_id":"1722393719287158300","api_key":"","timestamp":"","signature":"","trace_id":"0e30c04e4e7499bccde8f83990d7168a","req_header":{"trace_id":"0e30c04e4e7499bccde8f83990d7168a"},"req_param":[{"text":"apiv4-ws","currency_pair":"BTC_USDT","type":"limit","side":"BUY","amount":"-1","price":"-1"}]}},"request_id":"1722393719287158300","ack":true}`) -var orderSecondResponseError = []byte(`{"header":{"response_time":"1722400001367","status":"400","channel":"spot.order_place","event":"api","client_id":"14.203.57.50-0xc12e5e4f20","conn_id":"4ddf3b1b45523bc3","trace_id":"8cca91e29b405e334b1901463c36afe1"},"data":{"errs":{"label":"INVALID_PARAM_VALUE","message":"label: INVALID_PARAM_VALUE, message: Your order size 0.200000 USDT is too small. The minimum is 3 USDT"}},"request_id":"1722400001142974600"}`) -var orderSecondResponseSuccess = []byte(`{"header":{"response_time":"1722400187811","status":"200","channel":"spot.order_place","event":"api","client_id":"14.203.57.50-0xc1b81a7340"},"data":{"result":{"left":"0.0003","update_time":"1722400187","amount":"0.0003","create_time":"1722400187","price":"20000","finish_as":"open","time_in_force":"gtc","currency_pair":"BTC_USDT","type":"limit","account":"spot","side":"buy","amend_text":"-","text":"t-1722400187564025900","status":"open","iceberg":"0","filled_total":"0","id":"644865690097","fill_price":"0","update_time_ms":1722400187807,"create_time_ms":1722400187807}},"request_id":"1722400187564025900"}`) -var orderBatchSuccess = []byte(`{"header":{"response_time":"1722402442822","status":"200","channel":"spot.order_place","event":"api","client_id":"14.203.57.50-0xc0e372e580"},"data":{"result":[{"account":"spot","status":"open","side":"buy","amount":"0.0003","id":"644883514616","create_time":"1722402442","update_time":"1722402442","text":"t-1722402442588484600","left":"0.0003","currency_pair":"BTC_USDT","type":"limit","finish_as":"open","price":"20000","time_in_force":"gtc","iceberg":"0","filled_total":"0","fill_price":"0","create_time_ms":1722402442819,"update_time_ms":1722402442819,"succeeded":true},{"account":"spot","status":"open","side":"buy","amount":"0.0003","id":"644883514625","create_time":"1722402442","update_time":"1722402442","text":"t-1722402442588484601","left":"0.0003","currency_pair":"BTC_USDT","type":"limit","finish_as":"open","price":"20000","time_in_force":"gtc","iceberg":"0","filled_total":"0","fill_price":"0","create_time_ms":1722402442821,"update_time_ms":1722402442821,"succeeded":true}]},"request_id":"172240244 -2588484600"}`) + got, err := g.WebsocketLogin(context.Background(), demonstrationConn, "spot.login") + require.NoError(t, err) + require.NotEmpty(t, got) +} func TestWebsocketOrderPlace(t *testing.T) { t.Parallel() @@ -69,18 +62,8 @@ func TestWebsocketOrderPlace(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) - require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) - for _, a := range g.GetAssetTypes(true) { - avail, err := g.GetAvailablePairs(a) - require.NoError(t, err) - if len(avail) > 1 { - avail = avail[:1] - } - require.NoError(t, g.SetPairs(avail, a, true)) - } - require.NoError(t, g.Websocket.Connect()) - g.GetBase().API.AuthenticatedSupport = true - g.GetBase().API.AuthenticatedWebsocketSupport = true + testexch.UpdatePairsOnce(t, g) + g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes // test single order got, err := g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, asset.Spot) @@ -111,27 +94,14 @@ func TestWebsocketOrderCancel(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) - require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) - for _, a := range g.GetAssetTypes(true) { - avail, err := g.GetAvailablePairs(a) - require.NoError(t, err) - if len(avail) > 1 { - avail = avail[:1] - } - require.NoError(t, g.SetPairs(avail, a, true)) - } - require.NoError(t, g.Websocket.Connect()) - g.GetBase().API.AuthenticatedSupport = true - g.GetBase().API.AuthenticatedWebsocketSupport = true + testexch.UpdatePairsOnce(t, g) + g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes got, err := g.WebsocketOrderCancel(context.Background(), "644913098758", btcusdt, "", asset.Spot) require.NoError(t, err) require.NotEmpty(t, got) } -var cancelAllfailed = []byte(`{"header":{"response_time":"1722407703038","status":"200","channel":"spot.order_cancel_ids","event":"api","client_id":"14.203.57.50-0xc36ba50dc0"},"data":{"result":[{"currency_pair":"BTC_USDT","id":"644913098758","label":"ORDER_NOT_FOUND","message":"Order not found"}]},"request_id":"1722407702811217700"}`) -var cancelAllSuccess = []byte(`{"header":{"response_time":"1722407800393","status":"200","channel":"spot.order_cancel_ids","event":"api","client_id":"14.203.57.50-0xc0ae1ed8c0"},"data":{"result":[{"currency_pair":"BTC_USDT","id":"644913101755","succeeded":true}]},"request_id":"1722407800174417400"}`) - func TestWebsocketOrderCancelAllByIDs(t *testing.T) { t.Parallel() out := WebsocketOrderCancelRequest{} @@ -149,18 +119,8 @@ func TestWebsocketOrderCancelAllByIDs(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) - require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) - for _, a := range g.GetAssetTypes(true) { - avail, err := g.GetAvailablePairs(a) - require.NoError(t, err) - if len(avail) > 1 { - avail = avail[:1] - } - require.NoError(t, g.SetPairs(avail, a, true)) - } - require.NoError(t, g.Websocket.Connect()) - g.GetBase().API.AuthenticatedSupport = true - g.GetBase().API.AuthenticatedWebsocketSupport = true + testexch.UpdatePairsOnce(t, g) + g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes out.OrderID = "644913101755" got, err := g.WebsocketOrderCancelAllByIDs(context.Background(), []WebsocketOrderCancelRequest{out}, asset.Spot) @@ -168,8 +128,6 @@ func TestWebsocketOrderCancelAllByIDs(t *testing.T) { require.NotEmpty(t, got) } -var cancelAllByPairSuccess = []byte(`{"header":{"response_time":"1722415590482","status":"200","channel":"spot.order_cancel_cp","event":"api","client_id":"58.169.146.133-0xc028f00b00"},"data":{"result":[{"left":"0.0003","update_time":"1722415590","amount":"0.0003","create_time":"1722406069","price":"20000","finish_as":"cancelled","time_in_force":"gtc","currency_pair":"BTC_USDT","type":"limit","account":"spot","side":"buy","amend_text":"-","text":"t-1722406069759058701","status":"cancelled","iceberg":"0","filled_total":"0","id":"644913101780","fill_price":"0","update_time_ms":1722415590471,"create_time_ms":1722406069992}]},"request_id":"1722415590230464500"}`) - func TestWebsocketOrderCancelAllByPair(t *testing.T) { t.Parallel() pair, err := currency.NewPairFromString("LTC_USDT") @@ -180,27 +138,14 @@ func TestWebsocketOrderCancelAllByPair(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) - require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) - for _, a := range g.GetAssetTypes(true) { - avail, err := g.GetAvailablePairs(a) - require.NoError(t, err) - if len(avail) > 1 { - avail = avail[:1] - } - require.NoError(t, g.SetPairs(avail, a, true)) - } - require.NoError(t, g.Websocket.Connect()) - g.GetBase().API.AuthenticatedSupport = true - g.GetBase().API.AuthenticatedWebsocketSupport = true + testexch.UpdatePairsOnce(t, g) + g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes got, err := g.WebsocketOrderCancelAllByPair(context.Background(), currency.EMPTYPAIR, order.Buy, "", asset.Spot) require.NoError(t, err) require.NotEmpty(t, got) } -var amendOrderError = []byte(`{"header":{"response_time":"1722420643127","status":"404","channel":"spot.order_amend","event":"api","client_id":"58.169.146.133-0xc1e615e6e0","conn_id":"71eb27ad8803a9bd","trace_id":"4d80b11b184b49bd540abd039f42a84d"},"data":{"errs":{"label":"ORDER_NOT_FOUND","message":"label: ORDER_NOT_FOUND, message: Order not found"}},"request_id":"1722420642896203600"}`) -var ammendOrderSuccess = []byte(`"header":{"response_time":"1722420772699","status":"200","channel":"spot.order_amend","event":"api","client_id":"58.169.146.133-0xc08c7c2f20"},"data":{"result":{"left":"0.0004","update_time":"1722420772","amount":"0.0004","create_time":"1722420733","price":"20000","finish_as":"open","time_in_force":"gtc","currency_pair":"BTC_USDT","type":"limit","account":"spot","side":"buy","amend_text":"-","text":"t-1722420733733908700","status":"open","iceberg":"0","filled_total":"0","id":"645029162673","fill_price":"0","update_time_ms":1722420772698,"create_time_ms":1722420733966}},"request_id":"1722420772476042600"}`) - func TestWebsocketOrderAmend(t *testing.T) { t.Parallel() @@ -226,20 +171,8 @@ func TestWebsocketOrderAmend(t *testing.T) { _, err = g.WebsocketOrderAmend(context.Background(), amend, 0) require.ErrorIs(t, err, common.ErrNotYetImplemented) - sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) - - require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) - for _, a := range g.GetAssetTypes(true) { - avail, err := g.GetAvailablePairs(a) - require.NoError(t, err) - if len(avail) > 1 { - avail = avail[:1] - } - require.NoError(t, g.SetPairs(avail, a, true)) - } - require.NoError(t, g.Websocket.Connect()) - g.GetBase().API.AuthenticatedSupport = true - g.GetBase().API.AuthenticatedWebsocketSupport = true + testexch.UpdatePairsOnce(t, g) + g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes amend.OrderID = "645029162673" got, err := g.WebsocketOrderAmend(context.Background(), amend, asset.Spot) @@ -247,9 +180,6 @@ func TestWebsocketOrderAmend(t *testing.T) { require.NotEmpty(t, got) } -var getOrderStatusError = []byte(`{"header":{"response_time":"1722417357718","status":"404","channel":"spot.order_status","event":"api","client_id":"58.169.146.133-0xc0e6013600","conn_id":"8ae56147f8a55b08","trace_id":"127ac043f3a762ae88b746122aba5e3b"},"data":{"errs":{"label":"ORDER_NOT_FOUND","message":"label: ORDER_NOT_FOUND, message: Order with ID 644999648436 not found"}},"request_id":"1722417357478800700"}`) -var getOrderStatusSuccess = []byte(`{"header":{"response_time":"1722417915985","status":"200","channel":"spot.order_status","event":"api","client_id":"58.169.146.133-0xc06e7ff1e0"},"data":{"result":{"left":"0.0003","update_time":"1722417700","amount":"0.0003","create_time":"1722416858","price":"20000","finish_as":"cancelled","time_in_force":"gtc","currency_pair":"BTC_USDT","type":"limit","account":"spot","side":"buy","amend_text":"-","text":"t-1722416858697102100","status":"cancelled","iceberg":"0","filled_total":"0","id":"644999650452","fill_price":"0","update_time_ms":1722417700653,"create_time_ms":1722416858942}},"request_id":"1722417915744467800"}`) - func TestWebsocketGetOrderStatus(t *testing.T) { t.Parallel() @@ -267,18 +197,8 @@ func TestWebsocketGetOrderStatus(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) - require.NoError(t, g.UpdateTradablePairs(context.Background(), false)) - for _, a := range g.GetAssetTypes(true) { - avail, err := g.GetAvailablePairs(a) - require.NoError(t, err) - if len(avail) > 1 { - avail = avail[:1] - } - require.NoError(t, g.SetPairs(avail, a, true)) - } - require.NoError(t, g.Websocket.Connect()) - g.GetBase().API.AuthenticatedSupport = true - g.GetBase().API.AuthenticatedWebsocketSupport = true + testexch.UpdatePairsOnce(t, g) + g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes pair, err = currency.NewPairFromString("BTC_USDT") require.NoError(t, err) @@ -287,3 +207,36 @@ func TestWebsocketGetOrderStatus(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, got) } + +// GetWebsocketInstance returns a websocket instance copy for testing. +// This restricts the pairs to a single pair per asset type to reduce test time. +func GetWebsocketInstance(t *testing.T, g *Gateio) *Gateio { + t.Helper() + + cpy := new(Gateio) + cpy.SetDefaults() + gConf, err := config.GetConfig().GetExchangeConfig("GateIO") + require.NoError(t, err) + gConf.API.AuthenticatedSupport = true + gConf.API.AuthenticatedWebsocketSupport = true + gConf.API.Credentials.Key = apiKey + gConf.API.Credentials.Secret = apiSecret + + require.NoError(t, cpy.Setup(gConf), "Test instance Setup must not error") + cpy.CurrencyPairs.Load(&g.CurrencyPairs) + + for _, a := range cpy.GetAssetTypes(true) { + if a != asset.Spot { + require.NoError(t, cpy.CurrencyPairs.SetAssetEnabled(a, false)) + continue + } + avail, err := cpy.GetAvailablePairs(a) + require.NoError(t, err) + if len(avail) > 1 { + avail = avail[:1] + } + require.NoError(t, cpy.SetPairs(avail, a, true)) + } + require.NoError(t, cpy.Websocket.Connect()) + return cpy +} From f7af44c9758eb692076aa38b583c8bd66ab7492a Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 1 Aug 2024 09:33:03 +1000 Subject: [PATCH 043/138] Add tests for stream package --- exchanges/gateio/websocket_request_test.go | 23 ++++--- exchanges/stream/websocket.go | 72 +++++----------------- exchanges/stream/websocket_test.go | 43 +++++++++++++ internal/testing/websocket/mock.go | 3 +- 4 files changed, 72 insertions(+), 69 deletions(-) diff --git a/exchanges/gateio/websocket_request_test.go b/exchanges/gateio/websocket_request_test.go index 060d293099c..159cedbee28 100644 --- a/exchanges/gateio/websocket_request_test.go +++ b/exchanges/gateio/websocket_request_test.go @@ -27,7 +27,7 @@ func TestWebsocketLogin(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) testexch.UpdatePairsOnce(t, g) - g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes route, err := g.GetWebsocketRoute(asset.Spot) require.NoError(t, err) @@ -63,7 +63,7 @@ func TestWebsocketOrderPlace(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) testexch.UpdatePairsOnce(t, g) - g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes // test single order got, err := g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, asset.Spot) @@ -76,9 +76,6 @@ func TestWebsocketOrderPlace(t *testing.T) { require.NotEmpty(t, got) } -var orderCancelError = []byte(`{"header":{"response_time":"1722405878406","status":"400","channel":"spot.order_cancel","event":"api","client_id":"14.203.57.50-0xc1e68ac6e0","conn_id":"0378a86ff109ca9a","trace_id":"b05be4753e751dff9175215ee020b578"},"data":{"errs":{"label":"INVALID_CURRENCY_PAIR","message":"label: INVALID_CURRENCY_PAIR, message: Invalid currency pair BTCUSD"}},"request_id":"1722405878175928500"}`) -var orderCancelSuccess = []byte(`{"header":{"response_time":"1722406252471","status":"200","channel":"spot.order_cancel","event":"api","client_id":"14.203.57.50-0xc2397b9e40"},"data":{"result":{"left":"0.0003","update_time":"1722406252","amount":"0.0003","create_time":"1722406069","price":"20000","finish_as":"cancelled","time_in_force":"gtc","currency_pair":"BTC_USDT","type":"limit","account":"spot","side":"buy","amend_text":"-","text":"t-1722406069442994700","status":"cancelled","iceberg":"0","filled_total":"0","id":"644913098758","fill_price":"0","update_time_ms":1722406252467,"create_time_ms":1722406069667}},"request_id":"1722406252236528200"}`) - func TestWebsocketOrderCancel(t *testing.T) { t.Parallel() _, err := g.WebsocketOrderCancel(context.Background(), "", currency.EMPTYPAIR, "", 0) @@ -95,7 +92,7 @@ func TestWebsocketOrderCancel(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) testexch.UpdatePairsOnce(t, g) - g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes got, err := g.WebsocketOrderCancel(context.Background(), "644913098758", btcusdt, "", asset.Spot) require.NoError(t, err) @@ -120,7 +117,7 @@ func TestWebsocketOrderCancelAllByIDs(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) testexch.UpdatePairsOnce(t, g) - g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes out.OrderID = "644913101755" got, err := g.WebsocketOrderCancelAllByIDs(context.Background(), []WebsocketOrderCancelRequest{out}, asset.Spot) @@ -139,7 +136,7 @@ func TestWebsocketOrderCancelAllByPair(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) testexch.UpdatePairsOnce(t, g) - g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes got, err := g.WebsocketOrderCancelAllByPair(context.Background(), currency.EMPTYPAIR, order.Buy, "", asset.Spot) require.NoError(t, err) @@ -171,8 +168,10 @@ func TestWebsocketOrderAmend(t *testing.T) { _, err = g.WebsocketOrderAmend(context.Background(), amend, 0) require.ErrorIs(t, err, common.ErrNotYetImplemented) + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + testexch.UpdatePairsOnce(t, g) - g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes amend.OrderID = "645029162673" got, err := g.WebsocketOrderAmend(context.Background(), amend, asset.Spot) @@ -198,7 +197,7 @@ func TestWebsocketGetOrderStatus(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) testexch.UpdatePairsOnce(t, g) - g := GetWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes pair, err = currency.NewPairFromString("BTC_USDT") require.NoError(t, err) @@ -208,9 +207,9 @@ func TestWebsocketGetOrderStatus(t *testing.T) { require.NotEmpty(t, got) } -// GetWebsocketInstance returns a websocket instance copy for testing. +// getWebsocketInstance returns a websocket instance copy for testing. // This restricts the pairs to a single pair per asset type to reduce test time. -func GetWebsocketInstance(t *testing.T, g *Gateio) *Gateio { +func getWebsocketInstance(t *testing.T, g *Gateio) *Gateio { t.Helper() cpy := new(Gateio) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 8c9fdf0f7c7..d579324f37c 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -2,7 +2,6 @@ package stream import ( "context" - "encoding/json" "errors" "fmt" "net" @@ -34,6 +33,10 @@ var ( ErrNotConnected = errors.New("websocket is not connected") ErrNoMessageListener = errors.New("websocket listener not found for message") ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature") + ErrRequestRouteNotFound = errors.New("request route not found") + ErrRequestRouteNotSet = errors.New("request route not set") + ErrSignatureNotSet = errors.New("signature not set") + ErrRequestPayloadNotSet = errors.New("request payload not set") ) // Private websocket errors @@ -70,6 +73,7 @@ var ( errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") errConnectionCandidateDuplication = errors.New("connection candidate duplication") errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") + errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") ) var ( @@ -429,19 +433,14 @@ func (w *Websocket) Connect() error { go w.Reader(context.TODO(), conn, w.connectionManager[i].Setup.Handler) if w.connectionManager[i].Setup.Authenticate != nil && w.CanUseAuthenticatedEndpoints() { - fmt.Println("Authenticating") err = w.connectionManager[i].Setup.Authenticate(context.TODO(), conn) if err != nil { - fmt.Println("Error authenticating", err) - } else { - fmt.Println("Authenticated") + // Opted to not fail entirely here for POC. This should be + // revisited and handled more gracefully. + log.Errorf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] failed to authenticate %v", w.exchangeName, i+1, conn.URL, err) } } - for _, sub := range subs { - fmt.Printf("Subscribing to %+v\n", sub) - } - err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs) if err != nil { multiConnectFatalError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err) @@ -1324,56 +1323,21 @@ func drain(ch <-chan error) { } } -var ErrRequestRouteNotFound = errors.New("request route not found") -var ErrRequestRouteNotSet = errors.New("request route not set") -var ErrSignatureNotSet = errors.New("signature not set") -var ErrRequestPayloadNotSet = errors.New("request payload not set") - -// SendRequest sends a request to a specific route and unmarhsals the response -// into the result. NOTE: Only for multi connection management. -func (w *Websocket) SendRequest(ctx context.Context, routeID string, signature, payload, result any) error { - if w == nil { - return fmt.Errorf("%w: Websocket", common.ErrNilPointer) - } - - if signature == nil { - return ErrSignatureNotSet - } - if payload == nil { - return ErrRequestPayloadNotSet - } - - outbound, err := w.GetOutboundConnection(routeID) - if err != nil { - return err - } - - // if w.verbose { - display, _ := json.Marshal(payload) - log.Debugf(log.WebsocketMgr, "%s websocket: sending request to %s. Data: %v", w.exchangeName, routeID, string(display)) - // } - - resp, err := outbound.SendMessageReturnResponse(ctx, signature, payload) - if err != nil { - return err - } - - // if w.verbose { - log.Debugf(log.WebsocketMgr, "%s websocket: received response from %s. Data: %s", w.exchangeName, routeID, resp) - // } - return json.Unmarshal(resp, result) -} - -var errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") - // GetOutboundConnection returns a connection specifically for outbound requests // for multi connection management. TODO: Upgrade routeID so that if there is // a URL change it can be handled. func (w *Websocket) GetOutboundConnection(routeID string) (Connection, error) { if w == nil { - return nil, fmt.Errorf("%w: Websocket", common.ErrNilPointer) + return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, w) + } + + if routeID == "" { + return nil, ErrRequestRouteNotSet } + w.m.Lock() + defer w.m.Unlock() + if !w.IsConnected() { return nil, ErrNotConnected } @@ -1382,10 +1346,6 @@ func (w *Websocket) GetOutboundConnection(routeID string) (Connection, error) { return nil, fmt.Errorf("%s: multi connection management not enabled %w please use exported Conn and AuthConn fields", w.exchangeName, errCannotObtainOutboundConnection) } - if routeID == "" { - return nil, ErrRequestRouteNotSet - } - for x := range w.connectionManager { if w.connectionManager[x].Setup.URL != routeID { continue diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 4e896f9591f..a2dec6e6b41 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -364,6 +364,8 @@ func TestConnectionMessageErrors(t *testing.T) { assert.ErrorIs(t, err, errNoPendingConnections, "Connect should error correctly") ws.useMultiConnectionManagement = true + ws.SetCanUseAuthenticatedEndpoints(true) + ws.verbose = true // NOTE: Intentional mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() @@ -371,6 +373,8 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Connect() require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) + ws.connectionManager[0].Setup.Authenticate = func(context.Context, Connection) error { return errDastardlyReason } + ws.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, errDastardlyReason } @@ -1452,3 +1456,42 @@ func TestCheckSubscriptions(t *testing.T) { err = ws.checkSubscriptions(nil, subscription.List{{}}) assert.NoError(t, err, "checkSubscriptions should not error") } + +func TestGetOutboundConnection(t *testing.T) { + t.Parallel() + var ws *Websocket + _, err := ws.GetOutboundConnection("") + require.ErrorIs(t, err, common.ErrNilPointer) + + ws = &Websocket{} + _, err = ws.GetOutboundConnection("") + require.ErrorIs(t, err, ErrRequestRouteNotSet) + + _, err = ws.GetOutboundConnection("testURL") + require.ErrorIs(t, err, ErrNotConnected) + + ws.setState(connectedState) + _, err = ws.GetOutboundConnection("testURL") + require.ErrorIs(t, err, errCannotObtainOutboundConnection) + + ws.useMultiConnectionManagement = true + _, err = ws.GetOutboundConnection("testURL") + require.ErrorIs(t, err, ErrRequestRouteNotFound) + + ws.connectionManager = []ConnectionWrapper{{ + Setup: &ConnectionSetup{URL: "testURL"}, + }} + + _, err = ws.GetOutboundConnection("testURL") + require.ErrorIs(t, err, ErrNotConnected) + + expected := &WebsocketConnection{} + ws.connectionManager = []ConnectionWrapper{{ + Setup: &ConnectionSetup{URL: "testURL"}, + Connection: expected, + }} + + conn, err := ws.GetOutboundConnection("testURL") + require.NoError(t, err) + assert.Same(t, expected, conn) +} diff --git a/internal/testing/websocket/mock.go b/internal/testing/websocket/mock.go index ce5f5dc0ab7..bbccf4ab2cf 100644 --- a/internal/testing/websocket/mock.go +++ b/internal/testing/websocket/mock.go @@ -36,7 +36,8 @@ func WsMockUpgrader(tb testing.TB, w http.ResponseWriter, r *http.Request, wsHan return } - if err != nil && strings.Contains(err.Error(), "wsarecv: An established connection was aborted by the software in your host machine.") { + if err != nil && (strings.Contains(err.Error(), "wsarecv: An established connection was aborted by the software in your host machine.") || + strings.Contains(err.Error(), "wsarecv: An existing connection was forcibly closed by the remote host.")) { return } From 3ed912beeff3bdad360542310f1c3a5588f79675 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 1 Aug 2024 09:40:46 +1000 Subject: [PATCH 044/138] rm unused field --- exchanges/gateio/gateio_wrapper.go | 1 - exchanges/stream/stream_types.go | 4 ---- 2 files changed, 5 deletions(-) diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 7537bb7d3bb..bbbf59a7603 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -213,7 +213,6 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Unsubscriber: g.SpotUnsubscribe, GenerateSubscriptions: g.GenerateDefaultSubscriptionsSpot, Connector: g.WsConnectSpot, - AllowOutbound: true, Authenticate: g.AuthenticateSpot, }) if err != nil { diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 13ab4f1d98d..1dc0c1d5e8b 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -67,10 +67,6 @@ type ConnectionSetup struct { // received from the exchange's websocket server. This function should // handle the incoming message and pass it to the appropriate data handler. Handler func(ctx context.Context, incoming []byte) error - // AllowOutbound is a flag that determines if the connection is allowed to - // send messages to the exchange's websocket server. This will allow the - // connection to be established without subscriptions needing to be made. - AllowOutbound bool // Authenticate is a function that will be called to authenticate the // connection to the exchange's websocket server. This function should // handle the authentication process and return an error if the From ee8a35c42d8e1fec5070600356acd4905e23970a Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 1 Aug 2024 10:01:25 +1000 Subject: [PATCH 045/138] glorious: nits --- exchanges/stream/stream_match.go | 6 +++--- exchanges/stream/stream_match_test.go | 12 +++++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/exchanges/stream/stream_match.go b/exchanges/stream/stream_match.go index 6f934503da8..a7b8a10bfab 100644 --- a/exchanges/stream/stream_match.go +++ b/exchanges/stream/stream_match.go @@ -12,7 +12,7 @@ var ( // NewMatch returns a new Match func NewMatch() *Match { - return &Match{m: make(map[any]incoming)} + return &Match{m: make(map[any]*incoming)} } // Match is a distributed subtype that handles the matching of requests and @@ -20,7 +20,7 @@ func NewMatch() *Match { // connections. Stream systems fan in all incoming payloads to one routine for // processing. type Match struct { - m map[any]incoming + m map[any]*incoming mu sync.Mutex } @@ -58,7 +58,7 @@ func (m *Match) Set(signature any, bufSize int) (<-chan []byte, error) { return nil, errSignatureCollision } ch := make(chan []byte, bufSize) - m.m[signature] = incoming{expected: bufSize, c: ch} + m.m[signature] = &incoming{expected: bufSize, c: ch} return ch, nil } diff --git a/exchanges/stream/stream_match_test.go b/exchanges/stream/stream_match_test.go index 52acd2f95f0..583b3c4b3ed 100644 --- a/exchanges/stream/stream_match_test.go +++ b/exchanges/stream/stream_match_test.go @@ -24,11 +24,17 @@ func TestMatch(t *testing.T) { assert.True(t, match.IncomingWithData("hello", []byte("hello"))) assert.Equal(t, "hello", string(<-ch)) - _, err = match.Set("hello", 1) + _, err = match.Set("hello", 2) assert.ErrorIs(t, err, errSignatureCollision, "Should error on signature collision") assert.True(t, match.IncomingWithData("hello", load), "Should match with matching message and signature") - assert.True(t, match.IncomingWithData("hello", load), "Should match with matching message and signature") + assert.False(t, match.IncomingWithData("hello", load), "Should not match with matching message and signature") + + assert.Len(t, ch, 1, "Channel should have 1 items, 1 was already read above") - assert.Len(t, ch, 2, "Channel should have 2 items") + ch, err = match.Set("masterblaster", 1) + require.NoError(t, err) + match.RemoveSignature("masterblaster") + garbage := <-ch // Should be closed and super slipery + require.Empty(t, garbage) } From cc97528da02626221cf431da8c5b8347b435dbb7 Mon Sep 17 00:00:00 2001 From: shazbert Date: Sun, 4 Aug 2024 06:17:01 +1000 Subject: [PATCH 046/138] rn files, specifically set function names to asset and offload routing to websocket type. --- ...st.go => gateio_websocket_request_spot.go} | 79 +++++------------- ... => gateio_websocket_request_spot_test.go} | 82 +++++++------------ ...s.go => gateio_websocket_request_types.go} | 0 exchanges/gateio/gateio_wrapper.go | 67 ++++++++------- exchanges/stream/stream_types.go | 5 ++ exchanges/stream/websocket.go | 40 +++++---- exchanges/stream/websocket_test.go | 17 ++-- exchanges/stream/websocket_types.go | 8 +- 8 files changed, 130 insertions(+), 168 deletions(-) rename exchanges/gateio/{websocket_request.go => gateio_websocket_request_spot.go} (72%) rename exchanges/gateio/{websocket_request_test.go => gateio_websocket_request_spot_test.go} (63%) rename exchanges/gateio/{websocket_request_types.go => gateio_websocket_request_types.go} (100%) diff --git a/exchanges/gateio/websocket_request.go b/exchanges/gateio/gateio_websocket_request_spot.go similarity index 72% rename from exchanges/gateio/websocket_request.go rename to exchanges/gateio/gateio_websocket_request_spot.go index 734a2855564..b721c9facd6 100644 --- a/exchanges/gateio/websocket_request.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -26,17 +26,6 @@ var ( errChannelEmpty = errors.New("channel cannot be empty") ) -// GetWebsocketRoute returns the route for a websocket request, this is a POC -// for the websocket wrapper. -func (g *Gateio) GetWebsocketRoute(a asset.Item) (string, error) { - switch a { - case asset.Spot: - return gateioWebsocketEndpoint, nil - default: - return "", common.ErrNotYetImplemented - } -} - // WebsocketLogin authenticates the websocket connection func (g *Gateio) WebsocketLogin(ctx context.Context, conn stream.Connection, channel string) (*WebsocketLoginResponse, error) { if conn == nil { @@ -93,10 +82,10 @@ func (g *Gateio) WebsocketLogin(ctx context.Context, conn stream.Connection, cha return &result, json.Unmarshal(inbound.Data, &result) } -// WebsocketOrderPlace places an order via the websocket connection. You can +// WebsocketOrderPlaceSpot places an order via the websocket connection. You can // send multiple orders in a single request. But only for one asset route. // So this can only batch spot orders or futures orders, not both. -func (g *Gateio) WebsocketOrderPlace(ctx context.Context, batch []WebsocketOrder, a asset.Item) ([]WebsocketOrderResponse, error) { +func (g *Gateio) WebsocketOrderPlaceSpot(ctx context.Context, batch []WebsocketOrder) ([]WebsocketOrderResponse, error) { if len(batch) == 0 { return nil, errBatchSliceEmpty } @@ -123,34 +112,25 @@ func (g *Gateio) WebsocketOrderPlace(ctx context.Context, batch []WebsocketOrder } } - route, err := g.GetWebsocketRoute(a) - if err != nil { - return nil, err - } - if len(batch) == 1 { var singleResponse WebsocketOrderResponse - err = g.SendWebsocketRequest(ctx, "spot.order_place", route, batch[0], &singleResponse, 2) + err := g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, batch[0], &singleResponse, 2) return []WebsocketOrderResponse{singleResponse}, err } var resp []WebsocketOrderResponse - err = g.SendWebsocketRequest(ctx, "spot.order_place", route, batch, &resp, 2) + err := g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, batch, &resp, 2) return resp, err } -// WebsocketOrderCancel cancels an order via the websocket connection -func (g *Gateio) WebsocketOrderCancel(ctx context.Context, orderID string, pair currency.Pair, account string, a asset.Item) (*WebsocketOrderResponse, error) { +// WebsocketOrderCancelSpot cancels an order via the websocket connection +func (g *Gateio) WebsocketOrderCancelSpot(ctx context.Context, orderID string, pair currency.Pair, account string) (*WebsocketOrderResponse, error) { if orderID == "" { return nil, order.ErrOrderIDNotSet } if pair.IsEmpty() { return nil, currency.ErrCurrencyPairEmpty } - route, err := g.GetWebsocketRoute(a) - if err != nil { - return nil, err - } params := &struct { OrderID string `json:"order_id"` // This requires order_id tag @@ -163,12 +143,12 @@ func (g *Gateio) WebsocketOrderCancel(ctx context.Context, orderID string, pair } var resp WebsocketOrderResponse - err = g.SendWebsocketRequest(ctx, "spot.order_cancel", route, params, &resp, 1) + err := g.SendWebsocketRequest(ctx, "spot.order_cancel", asset.Spot, params, &resp, 1) return &resp, err } -// WebsocketOrderCancelAllByIDs cancels multiple orders via the websocket -func (g *Gateio) WebsocketOrderCancelAllByIDs(ctx context.Context, o []WebsocketOrderCancelRequest, a asset.Item) ([]WebsocketCancellAllResponse, error) { +// WebsocketOrderCancelAllByIDsSpots cancels multiple orders via the websocket +func (g *Gateio) WebsocketOrderCancelAllByIDsSpot(ctx context.Context, o []WebsocketOrderCancelRequest) ([]WebsocketCancellAllResponse, error) { if len(o) == 0 { return nil, errNoOrdersToCancel } @@ -182,18 +162,13 @@ func (g *Gateio) WebsocketOrderCancelAllByIDs(ctx context.Context, o []Websocket } } - route, err := g.GetWebsocketRoute(a) - if err != nil { - return nil, err - } - var resp []WebsocketCancellAllResponse - err = g.SendWebsocketRequest(ctx, "spot.order_cancel_ids", route, o, &resp, 2) + err := g.SendWebsocketRequest(ctx, "spot.order_cancel_ids", asset.Spot, o, &resp, 2) return resp, err } -// WebsocketOrderCancelAllByPair cancels all orders for a specific pair -func (g *Gateio) WebsocketOrderCancelAllByPair(ctx context.Context, pair currency.Pair, side order.Side, account string, a asset.Item) ([]WebsocketOrderResponse, error) { +// WebsocketOrderCancelAllByPairSpot cancels all orders for a specific pair +func (g *Gateio) WebsocketOrderCancelAllByPairSpot(ctx context.Context, pair currency.Pair, side order.Side, account string) ([]WebsocketOrderResponse, error) { if !pair.IsEmpty() && side == order.UnknownSide { return nil, fmt.Errorf("%w: side cannot be unknown when pair is set as this will purge *ALL* open orders", errEdgeCaseIssue) } @@ -203,11 +178,6 @@ func (g *Gateio) WebsocketOrderCancelAllByPair(ctx context.Context, pair currenc sideStr = side.Lower() } - route, err := g.GetWebsocketRoute(a) - if err != nil { - return nil, err - } - params := &WebsocketCancelParam{ Pair: pair, Side: sideStr, @@ -215,12 +185,12 @@ func (g *Gateio) WebsocketOrderCancelAllByPair(ctx context.Context, pair currenc } var resp []WebsocketOrderResponse - err = g.SendWebsocketRequest(ctx, "spot.order_cancel_cp", route, params, &resp, 1) + err := g.SendWebsocketRequest(ctx, "spot.order_cancel_cp", asset.Spot, params, &resp, 1) return resp, err } // WebsocketOrderAmend amends an order via the websocket connection -func (g *Gateio) WebsocketOrderAmend(ctx context.Context, amend *WebsocketAmendOrder, a asset.Item) (*WebsocketOrderResponse, error) { +func (g *Gateio) WebsocketOrderAmendSpot(ctx context.Context, amend *WebsocketAmendOrder) (*WebsocketOrderResponse, error) { if amend == nil { return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, amend) } @@ -237,28 +207,19 @@ func (g *Gateio) WebsocketOrderAmend(ctx context.Context, amend *WebsocketAmendO return nil, fmt.Errorf("%w: amount or price must be set", errInvalidAmount) } - route, err := g.GetWebsocketRoute(a) - if err != nil { - return nil, err - } - var resp WebsocketOrderResponse - err = g.SendWebsocketRequest(ctx, "spot.order_amend", route, amend, &resp, 1) + err := g.SendWebsocketRequest(ctx, "spot.order_amend", asset.Spot, amend, &resp, 1) return &resp, err } -// WebsocketGetOrderStatus gets the status of an order via the websocket connection -func (g *Gateio) WebsocketGetOrderStatus(ctx context.Context, orderID string, pair currency.Pair, account string, a asset.Item) (*WebsocketOrderResponse, error) { +// WebsocketGetOrderStatusSpot gets the status of an order via the websocket connection +func (g *Gateio) WebsocketGetOrderStatusSpot(ctx context.Context, orderID string, pair currency.Pair, account string) (*WebsocketOrderResponse, error) { if orderID == "" { return nil, order.ErrOrderIDNotSet } if pair.IsEmpty() { return nil, currency.ErrCurrencyPairEmpty } - route, err := g.GetWebsocketRoute(a) - if err != nil { - return nil, err - } params := &struct { OrderID string `json:"order_id"` // This requires order_id tag @@ -271,18 +232,18 @@ func (g *Gateio) WebsocketGetOrderStatus(ctx context.Context, orderID string, pa } var resp WebsocketOrderResponse - err = g.SendWebsocketRequest(ctx, "spot.order_status", route, params, &resp, 1) + err := g.SendWebsocketRequest(ctx, "spot.order_status", asset.Spot, params, &resp, 1) return &resp, err } // SendWebsocketRequest sends a websocket request to the exchange -func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel, route string, params, result any, expectedResponses int) error { +func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connSignature, params, result any, expectedResponses int) error { paramPayload, err := json.Marshal(params) if err != nil { return err } - conn, err := g.Websocket.GetOutboundConnection(route) + conn, err := g.Websocket.GetOutboundConnection(connSignature) if err != nil { return err } diff --git a/exchanges/gateio/websocket_request_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go similarity index 63% rename from exchanges/gateio/websocket_request_test.go rename to exchanges/gateio/gateio_websocket_request_spot_test.go index 159cedbee28..80bbb86c88d 100644 --- a/exchanges/gateio/websocket_request_test.go +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -29,10 +29,7 @@ func TestWebsocketLogin(t *testing.T) { testexch.UpdatePairsOnce(t, g) g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes - route, err := g.GetWebsocketRoute(asset.Spot) - require.NoError(t, err) - - demonstrationConn, err := g.Websocket.GetOutboundConnection(route) + demonstrationConn, err := g.Websocket.GetOutboundConnection(asset.Spot) require.NoError(t, err) got, err := g.WebsocketLogin(context.Background(), demonstrationConn, "spot.login") @@ -40,25 +37,23 @@ func TestWebsocketLogin(t *testing.T) { require.NotEmpty(t, got) } -func TestWebsocketOrderPlace(t *testing.T) { +func TestWebsocketOrderPlaceSpot(t *testing.T) { t.Parallel() - _, err := g.WebsocketOrderPlace(context.Background(), nil, 0) + _, err := g.WebsocketOrderPlaceSpot(context.Background(), nil) require.ErrorIs(t, err, errBatchSliceEmpty) - _, err = g.WebsocketOrderPlace(context.Background(), make([]WebsocketOrder, 1), 0) + _, err = g.WebsocketOrderPlaceSpot(context.Background(), make([]WebsocketOrder, 1)) require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) out := WebsocketOrder{CurrencyPair: "BTC_USDT"} - _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, 0) + _, err = g.WebsocketOrderPlaceSpot(context.Background(), []WebsocketOrder{out}) require.ErrorIs(t, err, order.ErrSideIsInvalid) out.Side = strings.ToLower(order.Buy.String()) - _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, 0) + _, err = g.WebsocketOrderPlaceSpot(context.Background(), []WebsocketOrder{out}) require.ErrorIs(t, err, errInvalidAmount) out.Amount = "0.0003" out.Type = "limit" - _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, 0) + _, err = g.WebsocketOrderPlaceSpot(context.Background(), []WebsocketOrder{out}) require.ErrorIs(t, err, errInvalidPrice) out.Price = "20000" - _, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, 0) - require.ErrorIs(t, err, common.ErrNotYetImplemented) sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) @@ -66,71 +61,65 @@ func TestWebsocketOrderPlace(t *testing.T) { g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes // test single order - got, err := g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out}, asset.Spot) + got, err := g.WebsocketOrderPlaceSpot(context.Background(), []WebsocketOrder{out}) require.NoError(t, err) require.NotEmpty(t, got) // test batch orders - got, err = g.WebsocketOrderPlace(context.Background(), []WebsocketOrder{out, out}, asset.Spot) + got, err = g.WebsocketOrderPlaceSpot(context.Background(), []WebsocketOrder{out, out}) require.NoError(t, err) require.NotEmpty(t, got) } -func TestWebsocketOrderCancel(t *testing.T) { +func TestWebsocketOrderCancelSpot(t *testing.T) { t.Parallel() - _, err := g.WebsocketOrderCancel(context.Background(), "", currency.EMPTYPAIR, "", 0) + _, err := g.WebsocketOrderCancelSpot(context.Background(), "", currency.EMPTYPAIR, "") require.ErrorIs(t, err, order.ErrOrderIDNotSet) - _, err = g.WebsocketOrderCancel(context.Background(), "1337", currency.EMPTYPAIR, "", 0) + _, err = g.WebsocketOrderCancelSpot(context.Background(), "1337", currency.EMPTYPAIR, "") require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) btcusdt, err := currency.NewPairFromString("BTC_USDT") require.NoError(t, err) - _, err = g.WebsocketOrderCancel(context.Background(), "1337", btcusdt, "", 0) - require.ErrorIs(t, err, common.ErrNotYetImplemented) - sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) testexch.UpdatePairsOnce(t, g) g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes - got, err := g.WebsocketOrderCancel(context.Background(), "644913098758", btcusdt, "", asset.Spot) + got, err := g.WebsocketOrderCancelSpot(context.Background(), "644913098758", btcusdt, "") require.NoError(t, err) require.NotEmpty(t, got) } -func TestWebsocketOrderCancelAllByIDs(t *testing.T) { +func TestWebsocketOrderCancelAllByIDsSpot(t *testing.T) { t.Parallel() out := WebsocketOrderCancelRequest{} - _, err := g.WebsocketOrderCancelAllByIDs(context.Background(), []WebsocketOrderCancelRequest{out}, 0) + _, err := g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderCancelRequest{out}) require.ErrorIs(t, err, order.ErrOrderIDNotSet) out.OrderID = "1337" - _, err = g.WebsocketOrderCancelAllByIDs(context.Background(), []WebsocketOrderCancelRequest{out}, 0) + _, err = g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderCancelRequest{out}) require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) out.Pair, err = currency.NewPairFromString("BTC_USDT") require.NoError(t, err) - _, err = g.WebsocketOrderCancelAllByIDs(context.Background(), []WebsocketOrderCancelRequest{out}, 0) - require.ErrorIs(t, err, common.ErrNotYetImplemented) - sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) testexch.UpdatePairsOnce(t, g) g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes out.OrderID = "644913101755" - got, err := g.WebsocketOrderCancelAllByIDs(context.Background(), []WebsocketOrderCancelRequest{out}, asset.Spot) + got, err := g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderCancelRequest{out}) require.NoError(t, err) require.NotEmpty(t, got) } -func TestWebsocketOrderCancelAllByPair(t *testing.T) { +func TestWebsocketOrderCancelAllByPairSpot(t *testing.T) { t.Parallel() pair, err := currency.NewPairFromString("LTC_USDT") require.NoError(t, err) - _, err = g.WebsocketOrderCancelAllByPair(context.Background(), pair, 0, "", 0) + _, err = g.WebsocketOrderCancelAllByPairSpot(context.Background(), pair, 0, "") require.ErrorIs(t, err, errEdgeCaseIssue) sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) @@ -138,71 +127,62 @@ func TestWebsocketOrderCancelAllByPair(t *testing.T) { testexch.UpdatePairsOnce(t, g) g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes - got, err := g.WebsocketOrderCancelAllByPair(context.Background(), currency.EMPTYPAIR, order.Buy, "", asset.Spot) + got, err := g.WebsocketOrderCancelAllByPairSpot(context.Background(), currency.EMPTYPAIR, order.Buy, "") require.NoError(t, err) require.NotEmpty(t, got) } -func TestWebsocketOrderAmend(t *testing.T) { +func TestWebsocketOrderAmendSpot(t *testing.T) { t.Parallel() - _, err := g.WebsocketOrderAmend(context.Background(), nil, 0) + _, err := g.WebsocketOrderAmendSpot(context.Background(), nil) require.ErrorIs(t, err, common.ErrNilPointer) amend := &WebsocketAmendOrder{} - _, err = g.WebsocketOrderAmend(context.Background(), amend, 0) + _, err = g.WebsocketOrderAmendSpot(context.Background(), amend) require.ErrorIs(t, err, order.ErrOrderIDNotSet) amend.OrderID = "1337" - _, err = g.WebsocketOrderAmend(context.Background(), amend, 0) + _, err = g.WebsocketOrderAmendSpot(context.Background(), amend) require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) amend.Pair, err = currency.NewPairFromString("BTC_USDT") require.NoError(t, err) - _, err = g.WebsocketOrderAmend(context.Background(), amend, 0) + _, err = g.WebsocketOrderAmendSpot(context.Background(), amend) require.ErrorIs(t, err, errInvalidAmount) amend.Amount = "0.0004" - _, err = g.WebsocketOrderAmend(context.Background(), amend, 0) - require.ErrorIs(t, err, common.ErrNotYetImplemented) - sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) testexch.UpdatePairsOnce(t, g) g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes amend.OrderID = "645029162673" - got, err := g.WebsocketOrderAmend(context.Background(), amend, asset.Spot) + got, err := g.WebsocketOrderAmendSpot(context.Background(), amend) require.NoError(t, err) require.NotEmpty(t, got) } -func TestWebsocketGetOrderStatus(t *testing.T) { +func TestWebsocketGetOrderStatusSpot(t *testing.T) { t.Parallel() - _, err := g.WebsocketGetOrderStatus(context.Background(), "", currency.EMPTYPAIR, "", 0) + _, err := g.WebsocketGetOrderStatusSpot(context.Background(), "", currency.EMPTYPAIR, "") require.ErrorIs(t, err, order.ErrOrderIDNotSet) - _, err = g.WebsocketGetOrderStatus(context.Background(), "1337", currency.EMPTYPAIR, "", 0) + _, err = g.WebsocketGetOrderStatusSpot(context.Background(), "1337", currency.EMPTYPAIR, "") require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) - pair, err := currency.NewPairFromString("LTC_USDT") - require.NoError(t, err) - - _, err = g.WebsocketGetOrderStatus(context.Background(), "1337", pair, "", 0) - require.ErrorIs(t, err, common.ErrNotYetImplemented) - sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) testexch.UpdatePairsOnce(t, g) g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes - pair, err = currency.NewPairFromString("BTC_USDT") + pair, err := currency.NewPairFromString("BTC_USDT") require.NoError(t, err) - got, err := g.WebsocketGetOrderStatus(context.Background(), "644999650452", pair, "", asset.Spot) + got, err := g.WebsocketGetOrderStatusSpot(context.Background(), "644999650452", pair, "") require.NoError(t, err) require.NotEmpty(t, got) } diff --git a/exchanges/gateio/websocket_request_types.go b/exchanges/gateio/gateio_websocket_request_types.go similarity index 100% rename from exchanges/gateio/websocket_request_types.go rename to exchanges/gateio/gateio_websocket_request_types.go diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index bbbf59a7603..1cbbc9dfd36 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -204,16 +204,17 @@ func (g *Gateio) Setup(exch *config.Exchange) error { } // Spot connection err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ - URL: gateioWebsocketEndpoint, - RateLimit: gateioWebsocketRateLimit, - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: g.WsHandleSpotData, - Subscriber: g.SpotSubscribe, - Unsubscriber: g.SpotUnsubscribe, - GenerateSubscriptions: g.GenerateDefaultSubscriptionsSpot, - Connector: g.WsConnectSpot, - Authenticate: g.AuthenticateSpot, + URL: gateioWebsocketEndpoint, + RateLimit: gateioWebsocketRateLimit, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleSpotData, + Subscriber: g.SpotSubscribe, + Unsubscriber: g.SpotUnsubscribe, + GenerateSubscriptions: g.GenerateDefaultSubscriptionsSpot, + Connector: g.WsConnectSpot, + Authenticate: g.AuthenticateSpot, + OutboundRequestSignature: asset.Spot, }) if err != nil { return err @@ -227,10 +228,11 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.Futures) }, - Subscriber: g.FuturesSubscribe, - Unsubscriber: g.FuturesUnsubscribe, - GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, - Connector: g.WsFuturesConnect, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, + Connector: g.WsFuturesConnect, + OutboundRequestSignature: asset.USDTMarginedFutures, }) if err != nil { return err @@ -245,10 +247,11 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.Futures) }, - Subscriber: g.FuturesSubscribe, - Unsubscriber: g.FuturesUnsubscribe, - GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, - Connector: g.WsFuturesConnect, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, + Connector: g.WsFuturesConnect, + OutboundRequestSignature: asset.CoinMarginedFutures, }) if err != nil { return err @@ -264,10 +267,11 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.DeliveryFutures) }, - Subscriber: g.DeliveryFuturesSubscribe, - Unsubscriber: g.DeliveryFuturesUnsubscribe, - GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions, - Connector: g.WsDeliveryFuturesConnect, + Subscriber: g.DeliveryFuturesSubscribe, + Unsubscriber: g.DeliveryFuturesUnsubscribe, + GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions, + Connector: g.WsDeliveryFuturesConnect, + OutboundRequestSignature: asset.DeliveryFutures, }) if err != nil { return err @@ -275,15 +279,16 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Futures connection - Options return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ - URL: optionsWebsocketURL, - RateLimit: gateioWebsocketRateLimit, - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: g.WsHandleOptionsData, - Subscriber: g.OptionsSubscribe, - Unsubscriber: g.OptionsUnsubscribe, - GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, - Connector: g.WsOptionsConnect, + URL: optionsWebsocketURL, + RateLimit: gateioWebsocketRateLimit, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleOptionsData, + Subscriber: g.OptionsSubscribe, + Unsubscriber: g.OptionsUnsubscribe, + GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, + Connector: g.WsOptionsConnect, + OutboundRequestSignature: asset.Options, }) } diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 1dc0c1d5e8b..2263b12bcfd 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -72,6 +72,11 @@ type ConnectionSetup struct { // handle the authentication process and return an error if the // authentication fails. Authenticate func(ctx context.Context, conn Connection) error + // OutboundRequestSignature is any type that will match outbound + // requests to this specific connection. This could be an asset type + // `asset.Spot`, a string type denoting the individual URL, an + // authenticated or unauthenticated string or a mixture of these. + OutboundRequestSignature any } // ConnectionWrapper contains the connection setup details to be used when diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index d579324f37c..da8837693a5 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -71,7 +71,7 @@ var ( errAlreadyReconnecting = errors.New("websocket in the process of reconnection") errConnSetup = errors.New("error in connection setup") errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") - errConnectionCandidateDuplication = errors.New("connection candidate duplication") + errConnectionWrapperDuplication = errors.New("connection wrapper duplication") errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") ) @@ -104,6 +104,7 @@ func NewWebsocket() *Websocket { features: &protocol.Features{}, Orderbook: buffer.Orderbook{}, connections: make(map[Connection]*ConnectionWrapper), + outbound: make(map[any]*ConnectionWrapper), } } @@ -263,15 +264,20 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { } for x := range w.connectionManager { - if w.connectionManager[x].Setup.URL == c.URL { - return fmt.Errorf("%w: %w", errConnSetup, errConnectionCandidateDuplication) + // Below allows for multiple connections to the same URL with different + // outbound request signatures. This allows for easier determination of + // inbound and outbound messages. e.g. Gateio cross_margin, margin on + // a spot connection. + if w.connectionManager[x].Setup.URL == c.URL && c.OutboundRequestSignature == w.connectionManager[x].Setup.OutboundRequestSignature { + return fmt.Errorf("%w: %w", errConnSetup, errConnectionWrapperDuplication) } } - w.connectionManager = append(w.connectionManager, ConnectionWrapper{ + w.connectionManager = append(w.connectionManager, &ConnectionWrapper{ Setup: c, Subscriptions: subscription.NewStore(), }) + w.outbound[c.OutboundRequestSignature] = w.connectionManager[len(w.connectionManager)-1] return nil } @@ -421,7 +427,7 @@ func (w *Websocket) Connect() error { break } - w.connections[conn] = &w.connectionManager[i] + w.connections[conn] = w.connectionManager[i] w.connectionManager[i].Connection = conn if !conn.IsConnected() { @@ -1324,14 +1330,13 @@ func drain(ch <-chan error) { } // GetOutboundConnection returns a connection specifically for outbound requests -// for multi connection management. TODO: Upgrade routeID so that if there is -// a URL change it can be handled. -func (w *Websocket) GetOutboundConnection(routeID string) (Connection, error) { +// for multi connection management. +func (w *Websocket) GetOutboundConnection(connSignature any) (Connection, error) { if w == nil { return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, w) } - if routeID == "" { + if connSignature == "" { return nil, ErrRequestRouteNotSet } @@ -1346,15 +1351,14 @@ func (w *Websocket) GetOutboundConnection(routeID string) (Connection, error) { return nil, fmt.Errorf("%s: multi connection management not enabled %w please use exported Conn and AuthConn fields", w.exchangeName, errCannotObtainOutboundConnection) } - for x := range w.connectionManager { - if w.connectionManager[x].Setup.URL != routeID { - continue - } - if w.connectionManager[x].Connection == nil { - return nil, fmt.Errorf("%s: %w", w.connectionManager[x].Setup.URL, ErrNotConnected) - } - return w.connectionManager[x].Connection, nil + wrapper, ok := w.outbound[connSignature] + if !ok { + return nil, fmt.Errorf("%s: %w: %v", w.exchangeName, ErrRequestRouteNotFound, connSignature) + } + + if wrapper.Connection == nil { + return nil, fmt.Errorf("%s: %s %w: %v", w.exchangeName, wrapper.Setup.URL, ErrNotConnected, connSignature) } - return nil, fmt.Errorf("%w: %s", ErrRequestRouteNotFound, routeID) + return wrapper.Connection, nil } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index a2dec6e6b41..58f4e84f6cc 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -369,7 +369,7 @@ func TestConnectionMessageErrors(t *testing.T) { mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() - ws.connectionManager = []ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws" + mock.URL[len("http"):] + "/ws"}}} + ws.connectionManager = []*ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws" + mock.URL[len("http"):] + "/ws"}}} err = ws.Connect() require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) @@ -613,7 +613,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { amazingConn := multi.getConnectionFromSetup(amazingCandidate) multi.connections = map[Connection]*ConnectionWrapper{ - amazingConn: &multi.connectionManager[0], + amazingConn: multi.connectionManager[0], } subs, err = amazingCandidate.GenerateSubscriptions() @@ -1357,7 +1357,7 @@ func TestSetupNewConnection(t *testing.T) { require.Nil(t, multi.Conn) err = multi.SetupNewConnection(connSetup) - require.ErrorIs(t, err, errConnectionCandidateDuplication) + require.ErrorIs(t, err, errConnectionWrapperDuplication) } func TestWebsocketConnectionShutdown(t *testing.T) { @@ -1478,18 +1478,19 @@ func TestGetOutboundConnection(t *testing.T) { _, err = ws.GetOutboundConnection("testURL") require.ErrorIs(t, err, ErrRequestRouteNotFound) - ws.connectionManager = []ConnectionWrapper{{ + ws.connectionManager = []*ConnectionWrapper{{ Setup: &ConnectionSetup{URL: "testURL"}, }} + ws.outbound = map[any]*ConnectionWrapper{ + "testURL": ws.connectionManager[0], + } + _, err = ws.GetOutboundConnection("testURL") require.ErrorIs(t, err, ErrNotConnected) expected := &WebsocketConnection{} - ws.connectionManager = []ConnectionWrapper{{ - Setup: &ConnectionSetup{URL: "testURL"}, - Connection: expected, - }} + ws.connectionManager[0].Connection = expected conn, err := ws.GetOutboundConnection("testURL") require.NoError(t, err) diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index f663cb7c0cd..d736d67320a 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -50,8 +50,14 @@ type Websocket struct { m sync.Mutex connector func() error - connectionManager []ConnectionWrapper + connectionManager []*ConnectionWrapper connections map[Connection]*ConnectionWrapper + // outbound is map holding wrapper specific signatures to an active + // connection for outbound messaging. Wrapper specific connections + // might be asset specific e.g. spot, margin, futures or + // authenticated/unauthenticated or a mix of both. This map is used + // to send messages to the correct connection. + outbound map[any]*ConnectionWrapper subscriptions *subscription.Store From 66a8778d5a5dde0d0ce32b777f6ad2174906058b Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 6 Aug 2024 10:11:22 +1000 Subject: [PATCH 047/138] linter: fix --- exchanges/gateio/gateio_websocket_request_spot.go | 8 ++++---- exchanges/gateio/gateio_websocket_request_types.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index b721c9facd6..c66630380c0 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -44,7 +44,7 @@ func (g *Gateio) WebsocketLogin(ctx context.Context, conn stream.Connection, cha tn := time.Now() msg := "api\n" + channel + "\n" + "\n" + strconv.FormatInt(tn.Unix(), 10) mac := hmac.New(sha512.New, []byte(creds.Secret)) - if _, err := mac.Write([]byte(msg)); err != nil { + if _, err = mac.Write([]byte(msg)); err != nil { return nil, err } signature := hex.EncodeToString(mac.Sum(nil)) @@ -147,7 +147,7 @@ func (g *Gateio) WebsocketOrderCancelSpot(ctx context.Context, orderID string, p return &resp, err } -// WebsocketOrderCancelAllByIDsSpots cancels multiple orders via the websocket +// WebsocketOrderCancelAllByIDsSpot cancels multiple orders via the websocket func (g *Gateio) WebsocketOrderCancelAllByIDsSpot(ctx context.Context, o []WebsocketOrderCancelRequest) ([]WebsocketCancellAllResponse, error) { if len(o) == 0 { return nil, errNoOrdersToCancel @@ -189,7 +189,7 @@ func (g *Gateio) WebsocketOrderCancelAllByPairSpot(ctx context.Context, pair cur return resp, err } -// WebsocketOrderAmend amends an order via the websocket connection +// WebsocketOrderAmendSpot amends an order via the websocket connection func (g *Gateio) WebsocketOrderAmendSpot(ctx context.Context, amend *WebsocketAmendOrder) (*WebsocketOrderResponse, error) { if amend == nil { return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, amend) @@ -268,7 +268,7 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS } if len(responses) == 0 { - return fmt.Errorf("no responses received") + return errors.New("no responses received") } var inbound WebsocketAPIResponse diff --git a/exchanges/gateio/gateio_websocket_request_types.go b/exchanges/gateio/gateio_websocket_request_types.go index d662a3aa912..eb43077806d 100644 --- a/exchanges/gateio/gateio_websocket_request_types.go +++ b/exchanges/gateio/gateio_websocket_request_types.go @@ -119,7 +119,7 @@ type WebsocketOrderCancelRequest struct { Account string `json:"account,omitempty"` } -// WebsocketOrderCancelResponse defines a websocket order cancel response +// WebsocketCancellAllResponse defines a websocket order cancel response type WebsocketCancellAllResponse struct { Pair currency.Pair `json:"currency_pair"` Label string `json:"label"` From 431c0471cdd7e29622a72696c008214d8a12b802 Mon Sep 17 00:00:00 2001 From: shazbert Date: Thu, 15 Aug 2024 10:38:10 +1000 Subject: [PATCH 048/138] glorious: nits --- exchanges/stream/websocket.go | 6 +++--- exchanges/stream/websocket_connection.go | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 84210582f02..b1c8848228d 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -414,14 +414,14 @@ func (w *Websocket) Connect() error { break } - w.connections[conn] = &w.connectionManager[i] - w.connectionManager[i].Connection = conn - if !conn.IsConnected() { multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to connect", w.exchangeName, i+1, conn.URL) break } + w.connections[conn] = &w.connectionManager[i] + w.connectionManager[i].Connection = conn + w.Wg.Add(1) go w.Reader(context.TODO(), conn, w.connectionManager[i].Setup.Handler) diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index b421399dbf3..eb76faf6d88 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -205,12 +205,12 @@ func (w *WebsocketConnection) IsConnected() bool { func (w *WebsocketConnection) ReadMessage() Response { mType, resp, err := w.Connection.ReadMessage() if err != nil { - // Any error condition will return a Response{Raw: nil, Type: 0} which - // will force the reader routine to return. The connection will hang - // with no reader routine and its buffer will be written to from the - // active websocket connection. This should be handed over to - // `w.readMessageErrors` and managed by 'connectionMonitor' which needs - // to flush, reconnect and resubscribe the connection. + // If any error occurs, a Response{Raw: nil, Type: 0} is returned, causing the + // reader routine to exit. This leaves the connection without an active reader, + // leading to potential buffer issue from the ongoing websocket writes. + // Such errors are passed to `w.readMessageErrors` when the connection is active. + // The `connectionMonitor` handles these errors by flushing the buffer, reconnecting, + // and resubscribing to the websocket to restore the connection. if w.setConnectedStatus(false) { // NOTE: When w.setConnectedStatus() returns true the underlying // state was changed and this infers that the connection was From 7110e888e8c218155e50826319b6a3d74f849764 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 16 Aug 2024 12:02:50 +1000 Subject: [PATCH 049/138] add counter and update gateio --- common/common.go | 17 +++++++++++++++++ common/common_test.go | 15 +++++++++++++++ exchanges/gateio/gateio.go | 1 + exchanges/gateio/gateio_websocket.go | 2 +- exchanges/gateio/gateio_ws_delivery_futures.go | 6 +++--- exchanges/gateio/gateio_ws_futures.go | 6 +++--- exchanges/gateio/gateio_ws_option.go | 4 ++-- 7 files changed, 42 insertions(+), 9 deletions(-) diff --git a/common/common.go b/common/common.go index 2aec8fc8741..7ad4305c9a4 100644 --- a/common/common.go +++ b/common/common.go @@ -19,6 +19,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "unicode" @@ -671,3 +672,19 @@ func SortStrings[S ~[]E, E fmt.Stringer](x S) S { }) return n } + +// Counter is a thread-safe counter. +type Counter struct { + n int64 // privatised so you can't use counter as a value type +} + +// IncrementAndGet returns the next count after incrementing. +func (c *Counter) IncrementAndGet() int64 { + newID := atomic.AddInt64(&c.n, 1) + // Handle overflow by resetting the counter to 1 if it becomes negative + if newID < 0 { + atomic.StoreInt64(&c.n, 1) + return 1 + } + return newID +} diff --git a/common/common_test.go b/common/common_test.go index 93ebfe7a141..5382eab2c61 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -869,3 +869,18 @@ func (a A) String() string { func TestSortStrings(t *testing.T) { assert.Equal(t, []A{1, 2, 5, 6}, SortStrings([]A{6, 2, 5, 1})) } + +func TestCounter(t *testing.T) { + t.Parallel() + c := Counter{n: -5} + require.Equal(t, int64(1), c.IncrementAndGet()) + require.Equal(t, int64(2), c.IncrementAndGet()) +} + +// 683185328 1.787 ns/op 0 B/op 0 allocs/op +func BenchmarkCounter(b *testing.B) { + c := Counter{} + for i := 0; i < b.N; i++ { + c.IncrementAndGet() + } +} diff --git a/exchanges/gateio/gateio.go b/exchanges/gateio/gateio.go index 05efc142e32..0e2a870dba2 100644 --- a/exchanges/gateio/gateio.go +++ b/exchanges/gateio/gateio.go @@ -174,6 +174,7 @@ var ( // Gateio is the overarching type across this package type Gateio struct { exchange.Base + common.Counter } // ***************************************** SubAccounts ******************************** diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 1480510f772..1b9450c360d 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -819,7 +819,7 @@ func (g *Gateio) generatePayload(event string, channelsToSubscribe subscription. } payload := WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: g.Counter.IncrementAndGet(), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index b9242981033..1e8812976e7 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -84,7 +84,7 @@ func (g *Gateio) WsDeliveryFuturesConnect() error { g.Websocket.GetWebsocketURL()) } pingMessage, err := json.Marshal(WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: g.Counter.IncrementAndGet(), Time: time.Now().Unix(), Channel: futuresPingChannel, }) @@ -316,7 +316,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib } if strings.HasPrefix(channelsToSubscribe[i].Pairs[0].Quote.Upper().String(), "USDT") { payloads[0] = append(payloads[0], WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: g.Counter.IncrementAndGet(), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, @@ -325,7 +325,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib }) } else { payloads[1] = append(payloads[1], WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: g.Counter.IncrementAndGet(), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 19a640d227c..09311f33740 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -103,7 +103,7 @@ func (g *Gateio) WsFuturesConnect() error { g.Websocket.GetWebsocketURL()) } pingMessage, err := json.Marshal(WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: g.Counter.IncrementAndGet(), Time: func() int64 { return time.Now().Unix() }(), @@ -401,7 +401,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe subscr } if strings.HasPrefix(channelsToSubscribe[i].Pairs[0].Quote.Upper().String(), "USDT") { payloads[0] = append(payloads[0], WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: g.Counter.IncrementAndGet(), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, @@ -410,7 +410,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe subscr }) } else { payloads[1] = append(payloads[1], WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: g.Counter.IncrementAndGet(), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index c33a8493e6c..a492ae966ad 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -86,7 +86,7 @@ func (g *Gateio) WsOptionsConnect() error { return err } pingMessage, err := json.Marshal(WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: g.Counter.IncrementAndGet(), Time: time.Now().Unix(), Channel: optionsPingChannel, }) @@ -275,7 +275,7 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe subscr params...) } payloads[i] = WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: g.Counter.IncrementAndGet(), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, From b70296c667772802544f29dbdbf71ce3fdf67270 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 16 Aug 2024 12:20:30 +1000 Subject: [PATCH 050/138] fix collision issue --- exchanges/gateio/gateio_websocket_request_spot.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index c66630380c0..1d688b23aeb 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -50,7 +50,7 @@ func (g *Gateio) WebsocketLogin(ctx context.Context, conn stream.Connection, cha signature := hex.EncodeToString(mac.Sum(nil)) payload := WebsocketPayload{ - RequestID: strconv.FormatInt(tn.UnixNano(), 10), + RequestID: strconv.FormatInt(g.Counter.IncrementAndGet(), 10), APIKey: creds.Key, Signature: signature, Timestamp: strconv.FormatInt(tn.Unix(), 10), @@ -94,9 +94,7 @@ func (g *Gateio) WebsocketOrderPlaceSpot(ctx context.Context, batch []WebsocketO if batch[i].Text == "" { // For some reason the API requires a text field, or it will be // rejected in the second response. This is a workaround. - // +1 index for uniqueness in batch, when clock hasn't updated yet. - // TODO: Remove and use common counter. - batch[i].Text = "t-" + strconv.FormatInt(time.Now().UnixNano()+int64(i), 10) + batch[i].Text = "t-" + strconv.FormatInt(g.Counter.IncrementAndGet(), 10) } if batch[i].CurrencyPair == "" { return nil, currency.ErrCurrencyPairEmpty @@ -256,7 +254,7 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS Payload: WebsocketPayload{ // This request ID associated with the payload is the match to the // response. - RequestID: strconv.FormatInt(tn.UnixNano(), 10), + RequestID: strconv.FormatInt(g.Counter.IncrementAndGet(), 10), RequestParam: paramPayload, Timestamp: strconv.FormatInt(tn.Unix(), 10), }, From 0402bc787a99aa375c1ada6f443e59a8e1cfacc3 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 30 Aug 2024 14:03:00 +1000 Subject: [PATCH 051/138] Update exchanges/stream/websocket.go Co-authored-by: Scott --- exchanges/stream/websocket.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index fbda484646f..3635c6f9517 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -906,7 +906,7 @@ func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool { // SetWebsocketURL sets websocket URL and can refresh underlying connections func (w *Websocket) SetWebsocketURL(url string, auth, reconnect bool) error { if w.useMultiConnectionManagement { - // TODO: Enable multi connection management to change URL + // TODO: Add functionality for multi-connection management to change URL return fmt.Errorf("%s: %w", w.exchangeName, errCannotChangeConnectionURL) } defaultVals := url == "" || url == config.WebsocketURLNonDefaultMessage From 0dfda95a4225b3d1a1f72b75cc0bfe73b841a6d7 Mon Sep 17 00:00:00 2001 From: shazbert Date: Fri, 30 Aug 2024 16:02:35 +1000 Subject: [PATCH 052/138] glorious: nits --- exchanges/stream/websocket.go | 159 +++++++++++----------------- exchanges/stream/websocket_test.go | 9 -- exchanges/stream/websocket_types.go | 8 +- internal/testing/websocket/mock.go | 3 +- 4 files changed, 69 insertions(+), 110 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 3635c6f9517..7feef1259a8 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -4,12 +4,10 @@ import ( "context" "errors" "fmt" - "net" "net/url" "slices" "time" - "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" @@ -627,9 +625,11 @@ func (w *Websocket) Shutdown() error { defer w.Orderbook.FlushBuffer() // Shutdown managed connections - for conn := range w.connections { - if err := conn.Shutdown(); err != nil { - return err + for _, wrapper := range w.connectionManager { + if wrapper.Connection != nil { + if err := wrapper.Connection.Shutdown(); err != nil { + return err + } } } // Clean map of old connections @@ -679,48 +679,14 @@ func (w *Websocket) FlushChannels() error { } if w.features.Subscribe { - if w.GenerateSubs != nil { - newsubs, err := w.GenerateSubs() - if err != nil { - return err - } - - subs, unsubs := w.GetChannelDifference(nil, newsubs) - if len(unsubs) != 0 && w.features.Unsubscribe { - err := w.UnsubscribeChannels(nil, unsubs) - if err != nil { - return err - } - } - if len(subs) < 1 { - return nil - } - return w.SubscribeToChannels(nil, subs) + if !w.useMultiConnectionManagement { + return w.generateUnsubscribeAndSubscribe(nil, w.GenerateSubs) } for x := range w.connectionManager { - if w.connectionManager[x].Setup.GenerateSubscriptions == nil { - continue - } - newsubs, err := w.connectionManager[x].Setup.GenerateSubscriptions() - if err != nil { - if errors.Is(err, asset.ErrNotEnabled) { - continue - } + err := w.generateUnsubscribeAndSubscribe(w.connectionManager[x].Connection, w.connectionManager[x].Setup.GenerateSubscriptions) + if err != nil && !errors.Is(err, asset.ErrNotEnabled) { return err } - subs, unsubs := w.GetChannelDifference(w.connectionManager[x].Connection, newsubs) - if len(unsubs) != 0 && w.features.Unsubscribe { - err = w.UnsubscribeChannels(w.connectionManager[x].Connection, unsubs) - if err != nil { - return err - } - } - if len(subs) != 0 { - err = w.SubscribeToChannels(w.connectionManager[x].Connection, subs) - if err != nil { - return err - } - } } return nil } else if w.features.FullPayloadSubscribe { @@ -730,39 +696,15 @@ func (w *Websocket) FlushChannels() error { // would need to send ticker, orderbook and trades channel subscription // messages. - if w.GenerateSubs != nil { - newsubs, err := w.GenerateSubs() - if err != nil { - return err - } - - if len(newsubs) != 0 { - // Purge subscription list as there will be conflicts - w.subscriptions.Clear() - return w.SubscribeToChannels(nil, newsubs) - } - return nil + if !w.useMultiConnectionManagement { + return w.generateAndSubscribe(w.subscriptions, nil, w.GenerateSubs) } for x := range w.connectionManager { - if w.connectionManager[x].Setup.GenerateSubscriptions == nil { - continue - } - newsubs, err := w.connectionManager[x].Setup.GenerateSubscriptions() - if err != nil { - if errors.Is(err, asset.ErrNotEnabled) { - continue - } + err := w.generateAndSubscribe(w.connectionManager[x].Subscriptions, w.connectionManager[x].Connection, w.connectionManager[x].Setup.GenerateSubscriptions) + if err != nil && errors.Is(err, asset.ErrNotEnabled) { return err } - if len(newsubs) != 0 { - // Purge subscription list as there will be conflicts - w.connectionManager[x].Subscriptions.Clear() - err = w.SubscribeToChannels(w.connectionManager[x].Connection, newsubs) - if err != nil { - return err - } - } } return nil } @@ -773,6 +715,36 @@ func (w *Websocket) FlushChannels() error { return w.Connect() } +func (w *Websocket) generateUnsubscribeAndSubscribe(conn Connection, generate func() (subscription.List, error)) error { + newsubs, err := generate() + if err != nil { + return err + } + subs, unsubs := w.GetChannelDifference(conn, newsubs) + if len(unsubs) != 0 && w.features.Unsubscribe { + if err = w.UnsubscribeChannels(conn, unsubs); err != nil { + return err + } + } + if len(subs) == 0 { + return nil + } + return w.SubscribeToChannels(conn, subs) +} + +func (w *Websocket) generateAndSubscribe(store *subscription.Store, conn Connection, generate func() (subscription.List, error)) error { + newsubs, err := generate() + if err != nil { + return err + } + if len(newsubs) == 0 { + return nil + } + // Purge subscription list as there will be conflicts + store.Clear() + return w.SubscribeToChannels(conn, newsubs) +} + // trafficMonitor waits trafficCheckInterval before checking for a trafficAlert // 1 slot buffer means that connection will only write to trafficAlert once per trafficCheckInterval to avoid read/write flood in high traffic // Otherwise we Shutdown the connection after trafficTimeout, unless it's connecting. connectionMonitor is responsible for Connecting again @@ -979,8 +951,10 @@ func (w *Websocket) SetProxyAddress(proxyAddr string) error { log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", w.exchangeName) } - for conn := range w.connections { - conn.SetProxy(proxyAddr) + for _, wrapper := range w.connectionManager { + if wrapper.Connection != nil { + wrapper.Connection.SetProxy(proxyAddr) + } } if w.Conn != nil { w.Conn.SetProxy(proxyAddr) @@ -1034,28 +1008,26 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L if len(channels) == 0 { return nil // No channels to unsubscribe from is not an error } - if candidate, ok := w.connections[conn]; ok { - if candidate.Subscriptions == nil { - return nil // No channels to unsubscribe from is not an error - } - for _, s := range channels { - if candidate.Subscriptions.Get(s) == nil { - return fmt.Errorf("%w: %s", subscription.ErrNotFound, s) - } - } - return candidate.Setup.Unsubscriber(context.TODO(), conn, channels) + return w.unsubscribe(candidate.Subscriptions, channels, func(channels subscription.List) error { + return candidate.Setup.Unsubscriber(context.TODO(), conn, channels) + }) } + return w.unsubscribe(w.subscriptions, channels, func(channels subscription.List) error { + return w.Unsubscriber(channels) + }) +} - if w.subscriptions == nil { +func (w *Websocket) unsubscribe(store *subscription.Store, channels subscription.List, unsub func(channels subscription.List) error) error { + if store == nil { return nil // No channels to unsubscribe from is not an error } for _, s := range channels { - if w.subscriptions.Get(s) == nil { + if store.Get(s) == nil { return fmt.Errorf("%w: %s", subscription.ErrNotFound, s) } } - return w.Unsubscriber(channels) + return unsub(channels) } // ResubscribeToChannel resubscribes to channel @@ -1191,7 +1163,7 @@ func (w *Websocket) GetSubscription(key any) *subscription.Subscription { if w == nil || key == nil { return nil } - for _, c := range w.connections { + for _, c := range w.connectionManager { if c.Subscriptions == nil { continue } @@ -1212,7 +1184,7 @@ func (w *Websocket) GetSubscriptions() subscription.List { return nil } var subs subscription.List - for _, c := range w.connections { + for _, c := range w.connectionManager { if c.Subscriptions != nil { subs = append(subs, c.Subscriptions.List()...) } @@ -1233,17 +1205,6 @@ func (w *Websocket) CanUseAuthenticatedEndpoints() bool { return w.canUseAuthenticatedEndpoints.Load() } -// IsDisconnectionError Determines if the error sent over chan ReadMessageErrors is a disconnection error -func IsDisconnectionError(err error) bool { - if websocket.IsUnexpectedCloseError(err) { - return true - } - if _, ok := err.(*net.OpError); ok { - return !errors.Is(err, errClosedConnection) - } - return false -} - // checkWebsocketURL checks for a valid websocket url func checkWebsocketURL(s string) error { u, err := url.Parse(s) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index ad10485c189..e14eaacb0f3 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -8,7 +8,6 @@ import ( "encoding/json" "errors" "fmt" - "net" "net/http" "net/http/httptest" "os" @@ -292,14 +291,6 @@ func TestTrafficMonitorShutdown(t *testing.T) { } } -func TestIsDisconnectionError(t *testing.T) { - t.Parallel() - assert.False(t, IsDisconnectionError(errors.New("errorText")), "IsDisconnectionError should return false") - assert.True(t, IsDisconnectionError(&websocket.CloseError{Code: 1006, Text: "errorText"}), "IsDisconnectionError should return true") - assert.False(t, IsDisconnectionError(&net.OpError{Err: errClosedConnection}), "IsDisconnectionError should return false") - assert.True(t, IsDisconnectionError(&net.OpError{Err: errors.New("errText")}), "IsDisconnectionError should return true") -} - func TestConnectionMessageErrors(t *testing.T) { t.Parallel() var wsWrong = &Websocket{} diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index e00406bf925..5ae3d074f28 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -50,8 +50,14 @@ type Websocket struct { m sync.Mutex connector func() error + // connectionManager stores all *potential* connections for the exchange, organised within ConnectionWrapper structs. + // Each ConnectionWrapper one connection (will be expanded soon) tailored for specific exchange functionalities or asset types. // TODO: Expand this to support multiple connections per ConnectionWrapper + // For example, separate connections can be used for Spot, Margin, and Futures trading. This structure is especially useful + // for exchanges that differentiate between trading pairs by using different connection endpoints or protocols for various asset classes. + // If an exchange does not require such differentiation, all connections may be managed under a single ConnectionWrapper. connectionManager []ConnectionWrapper - connections map[Connection]*ConnectionWrapper + // connections holds a look up table for all connections to their corresponding ConnectionWrapper and subscription holder + connections map[Connection]*ConnectionWrapper subscriptions *subscription.Store diff --git a/internal/testing/websocket/mock.go b/internal/testing/websocket/mock.go index ce5f5dc0ab7..bbccf4ab2cf 100644 --- a/internal/testing/websocket/mock.go +++ b/internal/testing/websocket/mock.go @@ -36,7 +36,8 @@ func WsMockUpgrader(tb testing.TB, w http.ResponseWriter, r *http.Request, wsHan return } - if err != nil && strings.Contains(err.Error(), "wsarecv: An established connection was aborted by the software in your host machine.") { + if err != nil && (strings.Contains(err.Error(), "wsarecv: An established connection was aborted by the software in your host machine.") || + strings.Contains(err.Error(), "wsarecv: An existing connection was forcibly closed by the remote host.")) { return } From c58834e3d06ab5a33292589a78f944b48675ed78 Mon Sep 17 00:00:00 2001 From: shazbert Date: Fri, 30 Aug 2024 16:18:06 +1000 Subject: [PATCH 053/138] add tests --- exchanges/stream/websocket.go | 3 +-- exchanges/stream/websocket_test.go | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 7feef1259a8..a263f0f7c09 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -740,8 +740,7 @@ func (w *Websocket) generateAndSubscribe(store *subscription.Store, conn Connect if len(newsubs) == 0 { return nil } - // Purge subscription list as there will be conflicts - store.Clear() + store.Clear() // Purge subscription list as there will be conflicts return w.SubscribeToChannels(conn, newsubs) } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index e14eaacb0f3..e202993abde 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1256,6 +1256,28 @@ func TestFlushChannels(t *testing.T) { w.features.Unsubscribe = true err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") + + // Multi connection management + w.useMultiConnectionManagement = true + w.exchangeName = "multi" + amazingCandidate := &ConnectionSetup{ + URL: "AMAZING", + Connector: func(context.Context, Connection) error { return nil }, + GenerateSubscriptions: newgen.generateSubs, + Subscriber: func(ctx context.Context, c Connection, s subscription.List) error { + return currySimpleSubConn(w)(ctx, c, s) + }, + Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error { + return currySimpleUnsubConn(w)(ctx, c, s) + }, + Handler: func(context.Context, []byte) error { return nil }, + } + require.NoError(t, w.SetupNewConnection(amazingCandidate)) + require.NoError(t, w.FlushChannels(), "FlushChannels must not error") + + w.features.Subscribe = false + w.features.FullPayloadSubscribe = true + require.NoError(t, w.FlushChannels(), "FlushChannels must not error") } func TestDisable(t *testing.T) { From d29893b9eccbaf422abae78c0eda8293ef36fdb3 Mon Sep 17 00:00:00 2001 From: shazbert Date: Fri, 30 Aug 2024 16:22:44 +1000 Subject: [PATCH 054/138] linter: fix --- exchanges/stream/websocket.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index a263f0f7c09..d59caf009b4 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -722,7 +722,7 @@ func (w *Websocket) generateUnsubscribeAndSubscribe(conn Connection, generate fu } subs, unsubs := w.GetChannelDifference(conn, newsubs) if len(unsubs) != 0 && w.features.Unsubscribe { - if err = w.UnsubscribeChannels(conn, unsubs); err != nil { + if err := w.UnsubscribeChannels(conn, unsubs); err != nil { return err } } From 45ff199aa1d5fa975e7e8660d9d9a205afe1f48b Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 3 Sep 2024 08:54:15 +1000 Subject: [PATCH 055/138] After merge --- exchanges/gateio/gateio_wrapper.go | 8 ++++---- exchanges/stream/websocket_test.go | 7 +++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 3889fbee180..476d1b562ac 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -224,7 +224,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Futures connection - USDT margined err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: futuresWebsocketUsdtURL, - RateLimit: gateioWebsocketRateLimit, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: func(ctx context.Context, incoming []byte) error { @@ -243,7 +243,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Futures connection - BTC margined err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: futuresWebsocketBtcURL, - RateLimit: gateioWebsocketRateLimit, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: func(ctx context.Context, incoming []byte) error { @@ -263,7 +263,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Futures connection - Delivery - USDT margined err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: deliveryRealUSDTTradingURL, - RateLimit: gateioWebsocketRateLimit, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: func(ctx context.Context, incoming []byte) error { @@ -282,7 +282,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Futures connection - Options return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: optionsWebsocketURL, - RateLimit: gateioWebsocketRateLimit, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: g.WsHandleOptionsData, diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 06653cdba3b..1e50e1ca7da 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -411,7 +411,7 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Connect() require.NoError(t, err) - err = ws.connectionManager[0].Connection.SendRawMessage(websocket.TextMessage, []byte("test")) + err = ws.connectionManager[0].Connection.SendRawMessage(context.Background(), websocket.TextMessage, []byte("test")) require.NoError(t, err) require.NoError(t, err) @@ -1336,7 +1336,10 @@ func TestSetupNewConnection(t *testing.T) { set.UseMultiConnectionManagement = true require.NoError(t, multi.Setup(&set)) - connSetup := &ConnectionSetup{} + err = multi.SetupNewConnection(nil) + require.ErrorIs(t, err, errExchangeConfigEmpty) + + connSetup := &ConnectionSetup{ResponseCheckTimeout: time.Millisecond} err = multi.SetupNewConnection(connSetup) require.ErrorIs(t, err, errDefaultURLIsEmpty) From 06acaacd6d4a813fac0be0edbe7616413fdf7730 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 3 Sep 2024 09:07:07 +1000 Subject: [PATCH 056/138] Add error connection info --- exchanges/stream/websocket.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 5f767e461bc..f626a7127ea 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -1266,7 +1266,7 @@ func (w *Websocket) Reader(ctx context.Context, conn Connection, handler func(ct return // Connection has been closed } if err := handler(ctx, resp.Raw); err != nil { - w.DataHandler <- err + w.DataHandler <- fmt.Errorf("connection URL:[%v] error: %w", conn.GetURL(), err) } } } From cf2e8d879b85fbde00e4f5ee6891257480ed05e5 Mon Sep 17 00:00:00 2001 From: shazbert Date: Thu, 5 Sep 2024 10:09:04 +1000 Subject: [PATCH 057/138] upgrade to upstream merge --- exchanges/gateio/gateio_websocket.go | 6 +++--- exchanges/gateio/gateio_websocket_request_spot.go | 4 ++-- exchanges/gateio/gateio_ws_delivery_futures.go | 8 ++++---- exchanges/gateio/gateio_ws_futures.go | 8 ++++---- exchanges/gateio/gateio_ws_option.go | 8 ++++---- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 8a54081236e..bb919c226fe 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -694,7 +694,7 @@ func (g *Gateio) GenerateDefaultSubscriptionsSpot() (subscription.List, error) { // handleSubscription sends a websocket message to receive data from the channel func (g *Gateio) handleSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generatePayload(ctx, event, channelsToSubscribe) + payloads, err := g.generatePayload(ctx, conn, event, channelsToSubscribe) if err != nil { return err } @@ -726,7 +726,7 @@ func (g *Gateio) handleSubscription(ctx context.Context, conn stream.Connection, return errs } -func (g *Gateio) generatePayload(ctx context.Context, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { +func (g *Gateio) generatePayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { return nil, errors.New("cannot generate payload, no channels supplied") } @@ -819,7 +819,7 @@ func (g *Gateio) generatePayload(ctx context.Context, event string, channelsToSu } payload := WsInput{ - ID: g.Counter.IncrementAndGet(), + ID: conn.GenerateMessageID(false), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index 1d688b23aeb..0afed68f2c3 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -50,7 +50,7 @@ func (g *Gateio) WebsocketLogin(ctx context.Context, conn stream.Connection, cha signature := hex.EncodeToString(mac.Sum(nil)) payload := WebsocketPayload{ - RequestID: strconv.FormatInt(g.Counter.IncrementAndGet(), 10), + RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), APIKey: creds.Key, Signature: signature, Timestamp: strconv.FormatInt(tn.Unix(), 10), @@ -254,7 +254,7 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS Payload: WebsocketPayload{ // This request ID associated with the payload is the match to the // response. - RequestID: strconv.FormatInt(g.Counter.IncrementAndGet(), 10), + RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), RequestParam: paramPayload, Timestamp: strconv.FormatInt(tn.Unix(), 10), }, diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index 1942e48ceab..825eb01e403 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -49,7 +49,7 @@ func (g *Gateio) WsDeliveryFuturesConnect(ctx context.Context, conn stream.Conne return err } pingMessage, err := json.Marshal(WsInput{ - ID: g.Counter.IncrementAndGet(), + ID: conn.GenerateMessageID(false), Time: time.Now().Unix(), // TODO: Func for dynamic time as this will be the same time for every ping message. Channel: futuresPingChannel, }) @@ -122,7 +122,7 @@ func (g *Gateio) DeliveryFuturesUnsubscribe(ctx context.Context, conn stream.Con // handleDeliveryFuturesSubscription sends a websocket message to receive data from the channel func (g *Gateio) handleDeliveryFuturesSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generateDeliveryFuturesPayload(ctx, event, channelsToSubscribe) + payloads, err := g.generateDeliveryFuturesPayload(ctx, conn, event, channelsToSubscribe) if err != nil { return err } @@ -155,7 +155,7 @@ func (g *Gateio) handleDeliveryFuturesSubscription(ctx context.Context, conn str return errs } -func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { +func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { return nil, errors.New("cannot generate payload, no channels supplied") } @@ -238,7 +238,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, event strin } } outbound = append(outbound, WsInput{ - ID: g.Counter.IncrementAndGet(), + ID: conn.GenerateMessageID(false), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 9fb8f54f210..4602372ec9c 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -68,7 +68,7 @@ func (g *Gateio) WsFuturesConnect(ctx context.Context, conn stream.Connection) e return err } pingMessage, err := json.Marshal(WsInput{ - ID: g.Counter.IncrementAndGet(), + ID: conn.GenerateMessageID(false), Time: time.Now().Unix(), // TODO: Func for dynamic time as this will be the same time for every ping message. Channel: futuresPingChannel, }) @@ -221,7 +221,7 @@ func (g *Gateio) WsHandleFuturesData(_ context.Context, respRaw []byte, a asset. // handleFuturesSubscription sends a websocket message to receive data from the channel func (g *Gateio) handleFuturesSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generateFuturesPayload(ctx, event, channelsToSubscribe) + payloads, err := g.generateFuturesPayload(ctx, conn, event, channelsToSubscribe) if err != nil { return err } @@ -257,7 +257,7 @@ func (g *Gateio) handleFuturesSubscription(ctx context.Context, conn stream.Conn return nil } -func (g *Gateio) generateFuturesPayload(ctx context.Context, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { +func (g *Gateio) generateFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { return nil, errors.New("cannot generate payload, no channels supplied") } @@ -343,7 +343,7 @@ func (g *Gateio) generateFuturesPayload(ctx context.Context, event string, chann } } outbound = append(outbound, WsInput{ - ID: g.Counter.IncrementAndGet(), + ID: conn.GenerateMessageID(false), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index 21e8e88277e..762ad1166e8 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -78,7 +78,7 @@ func (g *Gateio) WsOptionsConnect(ctx context.Context, conn stream.Connection) e return err } pingMessage, err := json.Marshal(WsInput{ - ID: g.Counter.IncrementAndGet(), + ID: conn.GenerateMessageID(false), Time: time.Now().Unix(), // TODO: Func for dynamic time as this will be the same time for every ping message. Channel: optionsPingChannel, }) @@ -163,7 +163,7 @@ getEnabledPairs: return subscriptions, nil } -func (g *Gateio) generateOptionsPayload(ctx context.Context, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { +func (g *Gateio) generateOptionsPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { return nil, errors.New("cannot generate payload, no channels supplied") } @@ -265,7 +265,7 @@ func (g *Gateio) generateOptionsPayload(ctx context.Context, event string, chann params...) } payloads[i] = WsInput{ - ID: g.Counter.IncrementAndGet(), + ID: conn.GenerateMessageID(false), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, @@ -288,7 +288,7 @@ func (g *Gateio) OptionsUnsubscribe(ctx context.Context, conn stream.Connection, // handleOptionsSubscription sends a websocket message to receive data from the channel func (g *Gateio) handleOptionsSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generateOptionsPayload(ctx, event, channelsToSubscribe) + payloads, err := g.generateOptionsPayload(ctx, conn, event, channelsToSubscribe) if err != nil { return err } From af986da9d328533a194d89fdc3c79ae36eb89b77 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Mon, 9 Sep 2024 10:59:00 +1000 Subject: [PATCH 058/138] Fix edge case where it does not reconnect made by an already closed connection --- exchanges/stream/websocket.go | 18 +++++++++++++++--- exchanges/stream/websocket_test.go | 20 +++++--------------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index f626a7127ea..665e4ff721f 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -634,11 +634,18 @@ func (w *Websocket) Shutdown() error { defer w.Orderbook.FlushBuffer() + // During the shutdown process, all errors are treated as non-fatal to avoid issues when the connection has already + // been closed. In such cases, attempting to close the connection may result in a + // "failed to send closeNotify alert (but connection was closed anyway)" error. Treating these errors as non-fatal + // prevents the shutdown process from being interrupted, which could otherwise trigger a continuous traffic monitor + // cycle and potentially block the initiation of a new connection. + var nonFatalCloseConnectionErrors error + // Shutdown managed connections for _, wrapper := range w.connectionManager { if wrapper.Connection != nil { if err := wrapper.Connection.Shutdown(); err != nil { - return err + nonFatalCloseConnectionErrors = common.AppendError(nonFatalCloseConnectionErrors, err) } } } @@ -651,12 +658,12 @@ func (w *Websocket) Shutdown() error { if w.Conn != nil { if err := w.Conn.Shutdown(); err != nil { - return err + nonFatalCloseConnectionErrors = common.AppendError(nonFatalCloseConnectionErrors, err) } } if w.AuthConn != nil { if err := w.AuthConn.Shutdown(); err != nil { - return err + nonFatalCloseConnectionErrors = common.AppendError(nonFatalCloseConnectionErrors, err) } } // flush any subscriptions from last connection if needed @@ -675,6 +682,11 @@ func (w *Websocket) Shutdown() error { // the cycle when `Connect` is called again and the connectionMonitor // starts but there is an old error in the channel. drain(w.ReadMessageErrors) + + if nonFatalCloseConnectionErrors != nil { + log.Warnf(log.WebsocketMgr, "%v websocket: shutdown error: %v", w.exchangeName, nonFatalCloseConnectionErrors) + } + return nil } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 1e50e1ca7da..20d35af3d56 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -426,10 +426,7 @@ func TestWebsocket(t *testing.T) { err := ws.SetProxyAddress("garbagio") assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") - ws.Conn = &dodgyConnection{} - ws.AuthConn = &WebsocketConnection{} ws.setEnabled(true) - err = ws.Setup(defaultSetup) // Sets to enabled again require.NoError(t, err, "Setup may not error") @@ -450,6 +447,7 @@ func TestWebsocket(t *testing.T) { ws.setState(connectedState) + ws.connector = func() error { return errDastardlyReason } err = ws.SetProxyAddress("https://192.168.0.1:1336") assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Connect and error from there") @@ -457,13 +455,9 @@ func TestWebsocket(t *testing.T) { assert.ErrorIs(t, err, errSameProxyAddress, "SetProxyAddress should error correctly") // removing proxy - err = ws.SetProxyAddress("") - assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Shutdown and error from there") - assert.ErrorIs(t, err, errCannotShutdown, "SetProxyAddress should call Shutdown and error from there") + assert.NoError(t, ws.SetProxyAddress("")) - ws.Conn = &WebsocketConnection{} ws.setEnabled(true) - // reinstate proxy err = ws.SetProxyAddress("http://localhost:1337") assert.NoError(t, err, "SetProxyAddress should not error") @@ -471,15 +465,11 @@ func TestWebsocket(t *testing.T) { assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly") + assert.ErrorIs(t, ws.Shutdown(), ErrNotConnected) ws.setState(connectedState) - ws.AuthConn = &dodgyConnection{} - err = ws.Shutdown() - assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn") - assert.ErrorIs(t, err, errCannotShutdown, "Shutdown should error correctly with a dodgy authConn") - - ws.AuthConn = &WebsocketConnection{} - ws.setState(disconnectedState) + assert.NoError(t, ws.Shutdown()) + ws.connector = func() error { return nil } err = ws.Connect() assert.NoError(t, err, "Connect should not error") From 117127856aa7124724d195d23f7ff6dbf048d718 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 12 Sep 2024 15:35:51 +1000 Subject: [PATCH 059/138] stream coverage --- exchanges/stream/websocket.go | 10 ++-- exchanges/stream/websocket_test.go | 93 ++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 5 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 665e4ff721f..962722c8cf9 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -738,11 +738,11 @@ func (w *Websocket) FlushChannels() error { } func (w *Websocket) generateUnsubscribeAndSubscribe(conn Connection, generate func() (subscription.List, error)) error { - newsubs, err := generate() + newSubs, err := generate() if err != nil { return err } - subs, unsubs := w.GetChannelDifference(conn, newsubs) + subs, unsubs := w.GetChannelDifference(conn, newSubs) if len(unsubs) != 0 && w.features.Unsubscribe { if err := w.UnsubscribeChannels(conn, unsubs); err != nil { return err @@ -755,15 +755,15 @@ func (w *Websocket) generateUnsubscribeAndSubscribe(conn Connection, generate fu } func (w *Websocket) generateAndSubscribe(store *subscription.Store, conn Connection, generate func() (subscription.List, error)) error { - newsubs, err := generate() + newSubs, err := generate() if err != nil { return err } - if len(newsubs) == 0 { + if len(newSubs) == 0 { return nil } store.Clear() // Purge subscription list as there will be conflicts - return w.SubscribeToChannels(conn, newsubs) + return w.SubscribeToChannels(conn, newSubs) } // trafficMonitor waits trafficCheckInterval before checking for a trafficAlert diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 20d35af3d56..afa94353af5 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -501,6 +501,12 @@ func TestWebsocket(t *testing.T) { err = ws.Shutdown() assert.NoError(t, err, "Shutdown should not error") ws.Wg.Wait() + + ws.useMultiConnectionManagement = true + + ws.connectionManager = []ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws://demos.kaazing.com/echo"}, Connection: &WebsocketConnection{}}} + err = ws.SetProxyAddress("https://192.168.0.1:1337") + require.NoError(t, err) } func currySimpleSub(w *Websocket) func(subscription.List) error { @@ -710,6 +716,17 @@ func TestConnectionMonitorNoConnection(t *testing.T) { assert.True(t, ws.IsConnectionMonitorRunning(), "IsConnectionMonitorRunning should return true") err = ws.connectionMonitor() assert.ErrorIs(t, err, errAlreadyRunning, "connectionMonitor should error correctly") + + ws.setState(connectedState) + ws.ReadMessageErrors <- errConnectionFault + select { + case data := <-ws.DataHandler: + err, ok := data.(error) + require.True(t, ok, "DataHandler should return an error") + require.ErrorIs(t, err, errConnectionFault, "DataHandler should return the correct error") + case <-time.After(2 * time.Second): + t.Fatal("DataHandler should return an error") + } } // TestGetSubscription logic test @@ -1462,3 +1479,79 @@ func TestCheckSubscriptions(t *testing.T) { err = ws.checkSubscriptions(nil, subscription.List{{}}) assert.NoError(t, err, "checkSubscriptions should not error") } + +func TestGenerateUnsubscribeAndSubscribe(t *testing.T) { + t.Parallel() + ws := Websocket{subscriptions: subscription.NewStore(), features: &protocol.Features{}} + ws.subscriptions.Add(&subscription.Subscription{Channel: subscription.MyOrdersChannel}) + + generateError := errors.New("foo fighters the generator") + err := ws.generateUnsubscribeAndSubscribe(&WebsocketConnection{}, func() (subscription.List, error) { + return nil, generateError + }) + require.ErrorIs(t, err, generateError) + + err = ws.generateUnsubscribeAndSubscribe(&WebsocketConnection{}, func() (subscription.List, error) { + return subscription.List{{Channel: subscription.CandlesChannel}, {Channel: subscription.OrderbookChannel}}, nil + }) + require.ErrorIs(t, err, common.ErrNilPointer) + + failedSubscriberError := errors.New("failed subscriber") + ws.Subscriber = func(subscription.List) error { return failedSubscriberError } + err = ws.generateUnsubscribeAndSubscribe(&WebsocketConnection{}, func() (subscription.List, error) { + return subscription.List{{Channel: subscription.CandlesChannel}, {Channel: subscription.OrderbookChannel}}, nil + }) + require.ErrorIs(t, err, failedSubscriberError) + + failedUnSubscriberError := errors.New("failed unsubscriber") + ws.Subscriber = func(subscription.List) error { return nil } + ws.Unsubscriber = func(subscription.List) error { return failedUnSubscriberError } + ws.features.Unsubscribe = true + err = ws.generateUnsubscribeAndSubscribe(&WebsocketConnection{}, func() (subscription.List, error) { + return subscription.List{{Channel: subscription.CandlesChannel}, {Channel: subscription.OrderbookChannel}}, nil + }) + require.ErrorIs(t, err, failedUnSubscriberError) + + ws.Unsubscriber = func(subscription.List) error { return nil } + err = ws.generateUnsubscribeAndSubscribe(&WebsocketConnection{}, func() (subscription.List, error) { + return subscription.List{{Channel: subscription.CandlesChannel}, {Channel: subscription.OrderbookChannel}}, nil + }) + require.NoError(t, err) + + ws.Unsubscriber = func(subscription.List) error { return failedUnSubscriberError } + ws.Subscriber = func(subscription.List) error { return failedSubscriberError } + err = ws.generateUnsubscribeAndSubscribe(&WebsocketConnection{}, func() (subscription.List, error) { + return subscription.List{{Channel: subscription.MyOrdersChannel}}, nil + }) + require.NoError(t, err) +} + +func TestGenerateAndSubscribe(t *testing.T) { + t.Parallel() + + ws := Websocket{subscriptions: subscription.NewStore()} + + generateError := errors.New("foo fighters the generator") + err := ws.generateAndSubscribe(ws.subscriptions, &WebsocketConnection{}, func() (subscription.List, error) { + return nil, generateError + }) + require.ErrorIs(t, err, generateError) + + ws.Subscriber = func(subscription.List) error { return nil } + err = ws.generateAndSubscribe(ws.subscriptions, &WebsocketConnection{}, func() (subscription.List, error) { + return subscription.List{{Channel: subscription.CandlesChannel}, {Channel: subscription.OrderbookChannel}}, nil + }) + require.NoError(t, err) + + failedSubscriberError := errors.New("failed subscriber") + ws.Subscriber = func(subscription.List) error { return failedSubscriberError } + err = ws.generateAndSubscribe(ws.subscriptions, &WebsocketConnection{}, func() (subscription.List, error) { + return subscription.List{{Channel: subscription.CandlesChannel}, {Channel: subscription.OrderbookChannel}}, nil + }) + require.ErrorIs(t, err, failedSubscriberError) + + err = ws.generateAndSubscribe(ws.subscriptions, &WebsocketConnection{}, func() (subscription.List, error) { + return nil, nil + }) + require.NoError(t, err) +} From 4be62145a1a036b8d4bae1b2fb8be1f8b164e769 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 12 Sep 2024 16:30:24 +1000 Subject: [PATCH 060/138] glorious: nits --- exchanges/gateio/gateio_test.go | 29 +++++++++++++ exchanges/gateio/gateio_websocket.go | 13 +++--- .../gateio/gateio_ws_delivery_futures.go | 41 +----------------- exchanges/gateio/gateio_ws_futures.go | 43 +------------------ exchanges/gateio/gateio_ws_option.go | 39 +---------------- 5 files changed, 43 insertions(+), 122 deletions(-) diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index 6a1c43198d1..4e5b404be13 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -24,6 +24,8 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -2980,6 +2982,10 @@ func TestGenerateFuturesDefaultSubscriptions(t *testing.T) { if _, err := g.GenerateFuturesDefaultSubscriptions(currency.USDT); err != nil { t.Error(err) } + + if _, err := g.GenerateFuturesDefaultSubscriptions(currency.BTC); err != nil { + t.Error(err) + } } func TestGenerateOptionsDefaultSubscriptions(t *testing.T) { t.Parallel() @@ -3611,3 +3617,26 @@ func TestGenerateWebsocketMessageID(t *testing.T) { t.Parallel() require.NotEmpty(t, g.GenerateWebsocketMessageID(false)) } + +type DummyConnection struct{ stream.Connection } + +func (d *DummyConnection) GenerateMessageID(bool) int64 { return 1337 } +func (d *DummyConnection) SendMessageReturnResponse(ctx context.Context, signature any, request any) ([]byte, error) { + return []byte(`{"time":1726121320,"time_ms":1726121320745,"id":1,"conn_id":"f903779a148987ca","trace_id":"d8ee37cd14347e4ed298d44e69aedaa7","channel":"spot.tickers","event":"subscribe","payload":["BRETT_USDT"],"result":{"status":"success"},"requestId":"d8ee37cd14347e4ed298d44e69aedaa7"}`), nil +} + +func TestHandleSubscriptions(t *testing.T) { + t.Parallel() + + subs := subscription.List{{Channel: subscription.OrderbookChannel}} + + err := g.handleSubscription(context.Background(), &DummyConnection{}, subscribeEvent, subs, func(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { + return []WsInput{{}}, nil + }) + require.NoError(t, err) + + err = g.handleSubscription(context.Background(), &DummyConnection{}, unsubscribeEvent, subs, func(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { + return []WsInput{{}}, nil + }) + require.NoError(t, err) +} diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 0790f14db73..058ce3ffea8 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -678,9 +678,12 @@ func (g *Gateio) GenerateDefaultSubscriptionsSpot() (subscription.List, error) { return subscriptions, nil } +// GeneratePayload returns the payload for a websocket message +type GeneratePayload func(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) + // handleSubscription sends a websocket message to receive data from the channel -func (g *Gateio) handleSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generatePayload(ctx, conn, event, channelsToSubscribe) +func (g *Gateio) handleSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List, generatePayload GeneratePayload) error { + payloads, err := generatePayload(ctx, conn, event, channelsToSubscribe) if err != nil { return err } @@ -699,7 +702,7 @@ func (g *Gateio) handleSubscription(ctx context.Context, conn stream.Connection, errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", payloads[k].Event, payloads[k].Channel, resp.Error.Code, resp.Error.Message)) continue } - if payloads[k].Event == subscribeEvent { + if event == subscribeEvent { err = g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[k]) } else { err = g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[k]) @@ -833,12 +836,12 @@ func (g *Gateio) generatePayload(ctx context.Context, conn stream.Connection, ev // SpotSubscribe sends a websocket message to stop receiving data from the channel func (g *Gateio) SpotSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe) + return g.handleSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe, g.generatePayload) } // SpotUnsubscribe sends a websocket message to stop receiving data from the channel func (g *Gateio) SpotUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe) + return g.handleSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe, g.generatePayload) } func (g *Gateio) listOfAssetsCurrencyPairEnabledFor(cp currency.Pair) map[asset.Item]bool { diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index 825eb01e403..18f08c9de63 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -4,13 +4,11 @@ import ( "context" "encoding/json" "errors" - "fmt" "net/http" "strconv" "time" "github.com/gorilla/websocket" - "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" @@ -112,47 +110,12 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.Lis // DeliveryFuturesSubscribe sends a websocket message to stop receiving data from the channel func (g *Gateio) DeliveryFuturesSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleDeliveryFuturesSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe) + return g.handleSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe, g.generateDeliveryFuturesPayload) } // DeliveryFuturesUnsubscribe sends a websocket message to stop receiving data from the channel func (g *Gateio) DeliveryFuturesUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleDeliveryFuturesSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe) -} - -// handleDeliveryFuturesSubscription sends a websocket message to receive data from the channel -func (g *Gateio) handleDeliveryFuturesSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generateDeliveryFuturesPayload(ctx, conn, event, channelsToSubscribe) - if err != nil { - return err - } - var errs error - var respByte []byte - for i, val := range payloads { - respByte, err = conn.SendMessageReturnResponse(ctx, val.ID, val) - if err != nil { - errs = common.AppendError(errs, err) - continue - } - var resp WsEventResponse - if err = json.Unmarshal(respByte, &resp); err != nil { - errs = common.AppendError(errs, err) - } else { - if resp.Error != nil && resp.Error.Code != 0 { - errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", val.Event, val.Channel, resp.Error.Code, resp.Error.Message)) - continue - } - if val.Event == subscribeEvent { - err = g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[i]) - } else { - err = g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[i]) - } - if err != nil { - errs = common.AppendError(errs, err) - } - } - } - return errs + return g.handleSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe, g.generateDeliveryFuturesPayload) } func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 4602372ec9c..3f4a94375bd 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -11,7 +11,6 @@ import ( "time" "github.com/gorilla/websocket" - "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" @@ -151,12 +150,12 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions(settlement currency.Code) ( // FuturesSubscribe sends a websocket message to stop receiving data from the channel func (g *Gateio) FuturesSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleFuturesSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe) + return g.handleSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe, g.generateFuturesPayload) } // FuturesUnsubscribe sends a websocket message to stop receiving data from the channel func (g *Gateio) FuturesUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleFuturesSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe) + return g.handleSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe, g.generateFuturesPayload) } // WsHandleFuturesData handles futures websocket data @@ -219,44 +218,6 @@ func (g *Gateio) WsHandleFuturesData(_ context.Context, respRaw []byte, a asset. } } -// handleFuturesSubscription sends a websocket message to receive data from the channel -func (g *Gateio) handleFuturesSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generateFuturesPayload(ctx, conn, event, channelsToSubscribe) - if err != nil { - return err - } - var errs error - var respByte []byte - for i, val := range payloads { - respByte, err = conn.SendMessageReturnResponse(ctx, val.ID, val) - if err != nil { - errs = common.AppendError(errs, err) - continue - } - var resp WsEventResponse - if err = json.Unmarshal(respByte, &resp); err != nil { - errs = common.AppendError(errs, err) - } else { - if resp.Error != nil && resp.Error.Code != 0 { - errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", val.Event, val.Channel, resp.Error.Code, resp.Error.Message)) - continue - } - if val.Event == subscribeEvent { - err = g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[i]) - } else { - err = g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[i]) - } - if err != nil { - errs = common.AppendError(errs, err) - } - } - } - if errs != nil { - return errs - } - return nil -} - func (g *Gateio) generateFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { return nil, errors.New("cannot generate payload, no channels supplied") diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index 762ad1166e8..488bedc50fe 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -11,7 +11,6 @@ import ( "time" "github.com/gorilla/websocket" - "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" @@ -278,46 +277,12 @@ func (g *Gateio) generateOptionsPayload(ctx context.Context, conn stream.Connect // OptionsSubscribe sends a websocket message to stop receiving data for asset type options func (g *Gateio) OptionsSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleOptionsSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe) + return g.handleSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe, g.generateOptionsPayload) } // OptionsUnsubscribe sends a websocket message to stop receiving data for asset type options func (g *Gateio) OptionsUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { - return g.handleOptionsSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe) -} - -// handleOptionsSubscription sends a websocket message to receive data from the channel -func (g *Gateio) handleOptionsSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generateOptionsPayload(ctx, conn, event, channelsToSubscribe) - if err != nil { - return err - } - var errs error - for k := range payloads { - result, err := conn.SendMessageReturnResponse(ctx, payloads[k].ID, payloads[k]) - if err != nil { - errs = common.AppendError(errs, err) - continue - } - var resp WsEventResponse - if err = json.Unmarshal(result, &resp); err != nil { - errs = common.AppendError(errs, err) - } else { - if resp.Error != nil && resp.Error.Code != 0 { - errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s asset type: options error code: %d message: %s", payloads[k].Event, payloads[k].Channel, resp.Error.Code, resp.Error.Message)) - continue - } - if payloads[k].Event == subscribeEvent { - err = g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[k]) - } else { - err = g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[k]) - } - if err != nil { - errs = common.AppendError(errs, err) - } - } - } - return errs + return g.handleSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe, g.generateOptionsPayload) } // WsHandleOptionsData handles options websocket data From 81cba36770252f00c5e2fde6f77f0520960fea7e Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 13 Sep 2024 09:45:27 +1000 Subject: [PATCH 061/138] glorious: nits removed asset error handling in stream package --- exchanges/gateio/gateio_websocket.go | 22 ++++---- exchanges/gateio/gateio_wrapper.go | 3 +- .../gateio/gateio_ws_delivery_futures.go | 12 +++-- exchanges/gateio/gateio_ws_futures.go | 50 ++++++++++--------- exchanges/gateio/gateio_ws_option.go | 11 ++-- exchanges/stream/websocket.go | 11 +--- 6 files changed, 57 insertions(+), 52 deletions(-) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 058ce3ffea8..c5af3010ed0 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -615,10 +615,7 @@ func (g *Gateio) processCrossMarginLoans(data []byte) error { func (g *Gateio) GenerateDefaultSubscriptionsSpot() (subscription.List, error) { channelsToSubscribe := defaultSubscriptions if g.Websocket.CanUseAuthenticatedEndpoints() { - channelsToSubscribe = append(channelsToSubscribe, []string{ - crossMarginBalanceChannel, - marginBalancesChannel, - spotBalancesChannel}...) + channelsToSubscribe = append(channelsToSubscribe, []string{crossMarginBalanceChannel, marginBalancesChannel, spotBalancesChannel}...) } if g.IsSaveTradeDataEnabled() || g.IsTradeFeedEnabled() { @@ -628,20 +625,25 @@ func (g *Gateio) GenerateDefaultSubscriptionsSpot() (subscription.List, error) { var subscriptions subscription.List var err error for i := range channelsToSubscribe { - var pairs []currency.Pair + var pairs currency.Pairs var assetType asset.Item switch channelsToSubscribe[i] { case marginBalancesChannel: assetType = asset.Margin - pairs, err = g.GetEnabledPairs(asset.Margin) + if pairs, err = g.GetEnabledPairs(asset.Margin); err != nil && !errors.Is(err, asset.ErrNotEnabled) { + return nil, err + } case crossMarginBalanceChannel: assetType = asset.CrossMargin - pairs, err = g.GetEnabledPairs(asset.CrossMargin) + if pairs, err = g.GetEnabledPairs(asset.CrossMargin); err != nil && !errors.Is(err, asset.ErrNotEnabled) { + return nil, err + } default: - // TODO: Check and add balance support as spot balances can be - // subscribed without a currency pair supplied. + // TODO: Check and add balance support as spot balances can be subscribed without a currency pair supplied. assetType = asset.Spot - pairs, err = g.GetEnabledPairs(asset.Spot) + if pairs, err = g.GetEnabledPairs(asset.Spot); err != nil && !errors.Is(err, asset.ErrNotEnabled) { + return nil, err + } } if err != nil { if errors.Is(err, asset.ErrNotEnabled) { diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 476d1b562ac..897097eb59c 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -155,11 +155,12 @@ func (g *Gateio) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - // TODO: Add websocket margin and cross margin support. + // TODO: Majority of margin REST endpoints are labelled as deprecated on the API docs. These will need to be removed. err = g.DisableAssetWebsocketSupport(asset.Margin) if err != nil { log.Errorln(log.ExchangeSys, err) } + // TODO: Add websocket cross margin support. err = g.DisableAssetWebsocketSupport(asset.CrossMargin) if err != nil { log.Errorln(log.ExchangeSys, err) diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index 18f08c9de63..00cc59d7698 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -78,13 +78,17 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.Lis futuresBalancesChannel, ) } - pairs, err := g.GetEnabledPairs(asset.DeliveryFutures) - if err != nil { - return nil, err - } var subscriptions subscription.List for i := range channelsToSubscribe { + pairs, err := g.GetEnabledPairs(asset.DeliveryFutures) + if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + continue // skip if not enabled + } + return nil, err + } + for j := range pairs { params := make(map[string]interface{}) switch channelsToSubscribe[i] { diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 3f4a94375bd..9ff89069b99 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -93,34 +93,37 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions(settlement currency.Code) ( futuresBalancesChannel, ) } - pairs, err := g.GetEnabledPairs(asset.Futures) - if err != nil { - return nil, err - } - switch { - case settlement.Equal(currency.USDT): - pairs, err = pairs.GetPairsByQuote(currency.USDT) + var subscriptions subscription.List + for i := range channelsToSubscribe { + pairs, err := g.GetEnabledPairs(asset.Futures) if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + continue // skip if not enabled + } return nil, err } - case settlement.Equal(currency.BTC): - offset := 0 - for x := range pairs { - if pairs[x].Quote.Equal(currency.USDT) { - continue // skip USDT pairs + + switch { + case settlement.Equal(currency.USDT): + pairs, err = pairs.GetPairsByQuote(currency.USDT) + if err != nil { + return nil, err } - pairs[offset] = pairs[x] - offset++ + case settlement.Equal(currency.BTC): + offset := 0 + for x := range pairs { + if pairs[x].Quote.Equal(currency.USDT) { + continue // skip USDT pairs + } + pairs[offset] = pairs[x] + offset++ + } + pairs = pairs[:offset] + default: + return nil, fmt.Errorf("settlement currency %s not supported", settlement) } - pairs = pairs[:offset] - default: - return nil, fmt.Errorf("settlement currency %s not supported", settlement) - } - subscriptions := make(subscription.List, len(channelsToSubscribe)*len(pairs)) - count := 0 - for i := range channelsToSubscribe { for j := range pairs { params := make(map[string]interface{}) switch channelsToSubscribe[i] { @@ -137,12 +140,11 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions(settlement currency.Code) ( if err != nil { return nil, err } - subscriptions[count] = &subscription.Subscription{ + subscriptions = append(subscriptions, &subscription.Subscription{ Channel: channelsToSubscribe[i], Pairs: currency.Pairs{fpair.Upper()}, Params: params, - } - count++ + }) } } return subscriptions, nil diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index 488bedc50fe..80f2309db1d 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -120,11 +120,14 @@ func (g *Gateio) GenerateOptionsDefaultSubscriptions() (subscription.List, error } getEnabledPairs: var subscriptions subscription.List - pairs, err := g.GetEnabledPairs(asset.Options) - if err != nil { - return nil, err - } for i := range channelsToSubscribe { + pairs, err := g.GetEnabledPairs(asset.Options) + if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + continue // skip if not enabled + } + return nil, err + } for j := range pairs { params := make(map[string]interface{}) switch channelsToSubscribe[i] { diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 962722c8cf9..2f460d479f0 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -10,7 +10,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/config" - "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" @@ -385,12 +384,6 @@ func (w *Websocket) Connect() error { subs, err := w.connectionManager[i].Setup.GenerateSubscriptions() // regenerate state on new connection if err != nil { - if errors.Is(err, asset.ErrNotEnabled) { - if w.verbose { - log.Warnf(log.WebsocketMgr, "%s websocket: %v", w.exchangeName, err) - } - continue // Non-fatal error, we can continue to the next connection - } multiConnectFatalError = fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) break } @@ -706,7 +699,7 @@ func (w *Websocket) FlushChannels() error { } for x := range w.connectionManager { err := w.generateUnsubscribeAndSubscribe(w.connectionManager[x].Connection, w.connectionManager[x].Setup.GenerateSubscriptions) - if err != nil && !errors.Is(err, asset.ErrNotEnabled) { + if err != nil { return err } } @@ -724,7 +717,7 @@ func (w *Websocket) FlushChannels() error { for x := range w.connectionManager { err := w.generateAndSubscribe(w.connectionManager[x].Subscriptions, w.connectionManager[x].Connection, w.connectionManager[x].Setup.GenerateSubscriptions) - if err != nil && errors.Is(err, asset.ErrNotEnabled) { + if err != nil { return err } } From bdc6954d83b0eb01d6035138531db20bc28f0bba Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 13 Sep 2024 10:03:32 +1000 Subject: [PATCH 062/138] linter: fix --- exchanges/gateio/gateio_test.go | 6 +++--- exchanges/stream/websocket_test.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index f51a3fce518..0becc47d3ed 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -3675,7 +3675,7 @@ func BenchmarkTime(b *testing.B) { type DummyConnection struct{ stream.Connection } func (d *DummyConnection) GenerateMessageID(bool) int64 { return 1337 } -func (d *DummyConnection) SendMessageReturnResponse(ctx context.Context, signature any, request any) ([]byte, error) { +func (d *DummyConnection) SendMessageReturnResponse(context.Context, any, any) ([]byte, error) { return []byte(`{"time":1726121320,"time_ms":1726121320745,"id":1,"conn_id":"f903779a148987ca","trace_id":"d8ee37cd14347e4ed298d44e69aedaa7","channel":"spot.tickers","event":"subscribe","payload":["BRETT_USDT"],"result":{"status":"success"},"requestId":"d8ee37cd14347e4ed298d44e69aedaa7"}`), nil } @@ -3684,12 +3684,12 @@ func TestHandleSubscriptions(t *testing.T) { subs := subscription.List{{Channel: subscription.OrderbookChannel}} - err := g.handleSubscription(context.Background(), &DummyConnection{}, subscribeEvent, subs, func(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { + err := g.handleSubscription(context.Background(), &DummyConnection{}, subscribeEvent, subs, func(context.Context, stream.Connection, string, subscription.List) ([]WsInput, error) { return []WsInput{{}}, nil }) require.NoError(t, err) - err = g.handleSubscription(context.Background(), &DummyConnection{}, unsubscribeEvent, subs, func(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { + err = g.handleSubscription(context.Background(), &DummyConnection{}, unsubscribeEvent, subs, func(context.Context, stream.Connection, string, subscription.List) ([]WsInput, error) { return []WsInput{{}}, nil }) require.NoError(t, err) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 136677111ab..2887acfd759 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1490,7 +1490,7 @@ func TestRemoveURLQueryString(t *testing.T) { func TestGenerateUnsubscribeAndSubscribe(t *testing.T) { t.Parallel() ws := Websocket{subscriptions: subscription.NewStore(), features: &protocol.Features{}} - ws.subscriptions.Add(&subscription.Subscription{Channel: subscription.MyOrdersChannel}) + require.NoError(t, ws.subscriptions.Add(&subscription.Subscription{Channel: subscription.MyOrdersChannel})) generateError := errors.New("foo fighters the generator") err := ws.generateUnsubscribeAndSubscribe(&WebsocketConnection{}, func() (subscription.List, error) { From 3fe44ca0d9d584b00e6ff50b0d98e96d304ec09d Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 13 Sep 2024 10:05:31 +1000 Subject: [PATCH 063/138] rm block --- exchanges/gateio/gateio_websocket.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index c5af3010ed0..a4185d1c3a5 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -645,12 +645,6 @@ func (g *Gateio) GenerateDefaultSubscriptionsSpot() (subscription.List, error) { return nil, err } } - if err != nil { - if errors.Is(err, asset.ErrNotEnabled) { - continue // Skip if asset is not enabled. - } - return nil, err - } for j := range pairs { params := make(map[string]interface{}) From 818584fbb4d65c78eb442ab8043ee41405e94928 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 13 Sep 2024 10:24:01 +1000 Subject: [PATCH 064/138] Add basic readme --- exchanges/stream/README.md | 126 +++++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 exchanges/stream/README.md diff --git a/exchanges/stream/README.md b/exchanges/stream/README.md new file mode 100644 index 00000000000..65d16fc9fc7 --- /dev/null +++ b/exchanges/stream/README.md @@ -0,0 +1,126 @@ +# GoCryptoTrader Exchange Stream Package + +This package is part of the GoCryptoTrader project and is responsible for handling exchange streaming data. + +## Overview + +The `stream` package provides functionalities to connect to various cryptocurrency exchanges and handle real-time data streams. + +## Features + +- Handle real-time market data streams +- Unified interface for managing data streams + +## Usage + +Here is a basic example of how to setup the `stream` package for websocket: + +```go +package main + +import ( + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + exchange "github.com/thrasher-corp/gocryptotrader/exchanges" + "github.com/thrasher-corp/gocryptotrader/exchanges/request" +) + +type Exchange struct { + exchange.Base +} + +// In the exchange wrapper this will set up the initial pointer field provided by exchange.Base +func (e *Exchange) SetDefault() { + e.Websocket = stream.NewWebsocket() + e.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit + e.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout + e.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit +} + +// In the exchange wrapper this is the original setup pattern for the websocket services +func (e *Exchange) Setup(exch *config.Exchange) error { + // This sets up global connection, sub, unsub and generate subscriptions for each connection defined below. + if err := e.Websocket.Setup(&stream.WebsocketSetup{ + ExchangeConfig: exch, + DefaultURL: connectionURLString, + RunningURL: connectionURLString, + Connector: e.WsConnect, + Subscriber: e.Subscribe, + Unsubscriber: e.Unsubscribe, + GenerateSubscriptions: e.GenerateDefaultSubscriptions, + Features: &e.Features.Supports.WebsocketCapabilities, + MaxWebsocketSubscriptionsPerConnection: 240, + OrderbookBufferConfig: buffer.Config{ Checksum: e.CalculateUpdateOrderbookChecksum }, + }); err != nil { + return err + } + + // This is a public websocket connection + if err := ok.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: connectionURLString, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exchangeWebsocketResponseMaxLimit, + RateLimit: request.NewRateLimitWithWeight(time.Second, 2, 1), + }); err != nil { + return err + } + + // This is a private websocket connection + return ok.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: privateConnectionURLString, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exchangeWebsocketResponseMaxLimit, + Authenticated: true, + RateLimit: request.NewRateLimitWithWeight(time.Second, 2, 1), + }) +} + +// The example below provides the now optional multi connection management system which allows for more connections +// to be maintained and established based off URL, connections types, asset types etc. +func (e *Exchange) Setup(exch *config.Exchange) error { + // This sets up global connection, sub, unsub and generate subscriptions for each connection defined below. + if err := e.Websocket.Setup(&stream.WebsocketSetup{ + ExchangeConfig: exch, + Features: &e.Features.Supports.WebsocketCapabilities, + FillsFeed: e.Features.Enabled.FillsFeed, + TradeFeed: e.Features.Enabled.TradeFeed, + UseMultiConnectionManagement: true, + }) + if err != nil { + return err + } + // Spot connection + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: connectionURLStringForSpot, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + // Custom handlers for the specific connection: + Handler: e.WsHandleSpotData, + Subscriber: e.SpotSubscribe, + Unsubscriber: e.SpotUnsubscribe, + GenerateSubscriptions: e.GenerateDefaultSubscriptionsSpot, + Connector: e.WsConnectSpot, + BespokeGenerateMessageID: e.GenerateWebsocketMessageID, + }) + if err != nil { + return err + } + // Futures connection - USDT margined + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: connectionURLStringForSpotForFutures, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + // Custom handlers for the specific connection: + Handler: func(ctx context.Context, incoming []byte) error { return e.WsHandleFuturesData(ctx, incoming, asset.Futures) }, + Subscriber: e.FuturesSubscribe, + Unsubscriber: e.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return e.GenerateFuturesDefaultSubscriptions(currency.USDT) }, + Connector: e.WsFuturesConnect, + BespokeGenerateMessageID: e.GenerateWebsocketMessageID, + }) + if err != nil { + return err + } +} +``` From 0709ba30eaab1b702eb19bc0230dc4100ac55215 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Mon, 16 Sep 2024 14:30:57 +1000 Subject: [PATCH 065/138] fix asset enabled flush cycle for multi connection --- engine/rpcserver.go | 7 + exchanges/protocol/features.go | 65 +++++---- exchanges/stream/websocket.go | 159 ++++++++++++----------- exchanges/stream/websocket_connection.go | 5 +- exchanges/stream/websocket_test.go | 153 +++++----------------- exchanges/stream/websocket_types.go | 4 +- 6 files changed, 158 insertions(+), 235 deletions(-) diff --git a/engine/rpcserver.go b/engine/rpcserver.go index 43b5e3814f7..72e7e76e8b6 100644 --- a/engine/rpcserver.go +++ b/engine/rpcserver.go @@ -2935,6 +2935,13 @@ func (s *RPCServer) SetExchangeAsset(_ context.Context, r *gctrpc.SetExchangeAss return nil, err } + if base.IsWebsocketEnabled() && base.Websocket.IsConnected() { + err = exch.FlushWebsocketChannels() + if err != nil { + return nil, err + } + } + return &gctrpc.GenericResponse{Status: MsgStatusSuccess}, nil } diff --git a/exchanges/protocol/features.go b/exchanges/protocol/features.go index b3c0c66662b..d7a0a43d37d 100644 --- a/exchanges/protocol/features.go +++ b/exchanges/protocol/features.go @@ -3,40 +3,37 @@ package protocol // Features holds all variables for the exchanges supported features // for a protocol (e.g REST or Websocket) type Features struct { - TickerBatching bool `json:"tickerBatching,omitempty"` - AutoPairUpdates bool `json:"autoPairUpdates,omitempty"` - AccountBalance bool `json:"accountBalance,omitempty"` - CryptoDeposit bool `json:"cryptoDeposit,omitempty"` - CryptoWithdrawal bool `json:"cryptoWithdrawal,omitempty"` - FiatWithdraw bool `json:"fiatWithdraw,omitempty"` - GetOrder bool `json:"getOrder,omitempty"` - GetOrders bool `json:"getOrders,omitempty"` - CancelOrders bool `json:"cancelOrders,omitempty"` - CancelOrder bool `json:"cancelOrder,omitempty"` - SubmitOrder bool `json:"submitOrder,omitempty"` - SubmitOrders bool `json:"submitOrders,omitempty"` - ModifyOrder bool `json:"modifyOrder,omitempty"` - DepositHistory bool `json:"depositHistory,omitempty"` - WithdrawalHistory bool `json:"withdrawalHistory,omitempty"` - TradeHistory bool `json:"tradeHistory,omitempty"` - UserTradeHistory bool `json:"userTradeHistory,omitempty"` - TradeFee bool `json:"tradeFee,omitempty"` - FiatDepositFee bool `json:"fiatDepositFee,omitempty"` - FiatWithdrawalFee bool `json:"fiatWithdrawalFee,omitempty"` - CryptoDepositFee bool `json:"cryptoDepositFee,omitempty"` - CryptoWithdrawalFee bool `json:"cryptoWithdrawalFee,omitempty"` - TickerFetching bool `json:"tickerFetching,omitempty"` - KlineFetching bool `json:"klineFetching,omitempty"` - TradeFetching bool `json:"tradeFetching,omitempty"` - OrderbookFetching bool `json:"orderbookFetching,omitempty"` - AccountInfo bool `json:"accountInfo,omitempty"` - FiatDeposit bool `json:"fiatDeposit,omitempty"` - DeadMansSwitch bool `json:"deadMansSwitch,omitempty"` - FundingRateFetching bool `json:"fundingRateFetching"` - PredictedFundingRate bool `json:"predictedFundingRate,omitempty"` - // FullPayloadSubscribe flushes and changes full subscription on websocket - // connection by subscribing with full default stream channel list - FullPayloadSubscribe bool `json:"fullPayloadSubscribe,omitempty"` + TickerBatching bool `json:"tickerBatching,omitempty"` + AutoPairUpdates bool `json:"autoPairUpdates,omitempty"` + AccountBalance bool `json:"accountBalance,omitempty"` + CryptoDeposit bool `json:"cryptoDeposit,omitempty"` + CryptoWithdrawal bool `json:"cryptoWithdrawal,omitempty"` + FiatWithdraw bool `json:"fiatWithdraw,omitempty"` + GetOrder bool `json:"getOrder,omitempty"` + GetOrders bool `json:"getOrders,omitempty"` + CancelOrders bool `json:"cancelOrders,omitempty"` + CancelOrder bool `json:"cancelOrder,omitempty"` + SubmitOrder bool `json:"submitOrder,omitempty"` + SubmitOrders bool `json:"submitOrders,omitempty"` + ModifyOrder bool `json:"modifyOrder,omitempty"` + DepositHistory bool `json:"depositHistory,omitempty"` + WithdrawalHistory bool `json:"withdrawalHistory,omitempty"` + TradeHistory bool `json:"tradeHistory,omitempty"` + UserTradeHistory bool `json:"userTradeHistory,omitempty"` + TradeFee bool `json:"tradeFee,omitempty"` + FiatDepositFee bool `json:"fiatDepositFee,omitempty"` + FiatWithdrawalFee bool `json:"fiatWithdrawalFee,omitempty"` + CryptoDepositFee bool `json:"cryptoDepositFee,omitempty"` + CryptoWithdrawalFee bool `json:"cryptoWithdrawalFee,omitempty"` + TickerFetching bool `json:"tickerFetching,omitempty"` + KlineFetching bool `json:"klineFetching,omitempty"` + TradeFetching bool `json:"tradeFetching,omitempty"` + OrderbookFetching bool `json:"orderbookFetching,omitempty"` + AccountInfo bool `json:"accountInfo,omitempty"` + FiatDeposit bool `json:"fiatDeposit,omitempty"` + DeadMansSwitch bool `json:"deadMansSwitch,omitempty"` + FundingRateFetching bool `json:"fundingRateFetching"` + PredictedFundingRate bool `json:"predictedFundingRate,omitempty"` Subscribe bool `json:"subscribe,omitempty"` Unsubscribe bool `json:"unsubscribe,omitempty"` AuthenticatedEndpoints bool `json:"authenticatedEndpoints,omitempty"` diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 4f05d3805b8..1803d67438a 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -64,7 +64,7 @@ var ( errAlreadyReconnecting = errors.New("websocket in the process of reconnection") errConnSetup = errors.New("error in connection setup") errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") - errConnectionCandidateDuplication = errors.New("connection candidate duplication") + errConnectionWrapperDuplication = errors.New("connection wrapper duplication") errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") errExchangeConfigEmpty = errors.New("exchange config is empty") ) @@ -246,7 +246,7 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { if w.useMultiConnectionManagement { // The connection and supporting functions are defined per connection - // and the connection candidate is stored in the connection manager. + // and the connection wrapper is stored in the connection manager. if c.URL == "" { return fmt.Errorf("%w: %w", errConnSetup, errDefaultURLIsEmpty) } @@ -268,7 +268,7 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { for x := range w.connectionManager { if w.connectionManager[x].Setup.URL == c.URL { - return fmt.Errorf("%w: %w", errConnSetup, errConnectionCandidateDuplication) + return fmt.Errorf("%w: %w", errConnSetup, errConnectionWrapperDuplication) } } @@ -303,7 +303,7 @@ func (w *Websocket) getConnectionFromSetup(c *ConnectionSetup) *WebsocketConnect ResponseMaxLimit: c.ResponseMaxLimit, Traffic: w.TrafficAlert, readMessageErrors: w.ReadMessageErrors, - ShutdownC: w.ShutdownC, + shutdown: make(chan struct{}), // Call shutdown to close the connection Wg: &w.Wg, Match: w.Match, RateLimit: c.RateLimit, @@ -637,19 +637,18 @@ func (w *Websocket) Shutdown() error { var nonFatalCloseConnectionErrors error // Shutdown managed connections - for _, wrapper := range w.connectionManager { - if wrapper.Connection != nil { - if err := wrapper.Connection.Shutdown(); err != nil { + for x := range w.connectionManager { + if w.connectionManager[x].Connection != nil { + if err := w.connectionManager[x].Connection.Shutdown(); err != nil { nonFatalCloseConnectionErrors = common.AppendError(nonFatalCloseConnectionErrors, err) } + w.connectionManager[x].Connection = nil + // Flush any subscriptions from last connection across any managed connections + w.connectionManager[x].Subscriptions.Clear() } } // Clean map of old connections clear(w.connections) - // Flush any subscriptions from last connection across any managed connections - for x := range w.connectionManager { - w.connectionManager[x].Subscriptions.Clear() - } if w.Conn != nil { if err := w.Conn.Shutdown(); err != nil { @@ -695,70 +694,76 @@ func (w *Websocket) FlushChannels() error { return fmt.Errorf("%s %w", w.exchangeName, ErrNotConnected) } - if w.features.Subscribe { - if !w.useMultiConnectionManagement { - return w.generateUnsubscribeAndSubscribe(nil, w.GenerateSubs) + // If the exchange does not support subscribing and or unsubscribing the full connection needs to be flushed to + // maintain consistency. + if !w.features.Subscribe || !w.features.Unsubscribe { + if err := w.Shutdown(); err != nil { + return err } - for x := range w.connectionManager { - err := w.generateUnsubscribeAndSubscribe(w.connectionManager[x].Connection, w.connectionManager[x].Setup.GenerateSubscriptions) - if err != nil { - return err - } + return w.Connect() + } + + if !w.useMultiConnectionManagement { + newSubs, err := w.GenerateSubs() + if err != nil { + return err } - return nil - } else if w.features.FullPayloadSubscribe { - // FullPayloadSubscribe means that the endpoint requires all - // subscriptions to be sent via the websocket connection e.g. if you are - // subscribed to ticker and orderbook but require trades as well, you - // would need to send ticker, orderbook and trades channel subscription - // messages. + subs, unsubs := w.GetChannelDifference(nil, newSubs) + if err := w.UnsubscribeChannels(nil, unsubs); err != nil { + return err + } + if len(subs) == 0 { + return nil + } + return w.SubscribeToChannels(nil, subs) + } - if !w.useMultiConnectionManagement { - return w.generateAndSubscribe(w.subscriptions, nil, w.GenerateSubs) + for x := range w.connectionManager { + newSubs, err := w.connectionManager[x].Setup.GenerateSubscriptions() + if err != nil { + return err } - for x := range w.connectionManager { - err := w.generateAndSubscribe(w.connectionManager[x].Subscriptions, w.connectionManager[x].Connection, w.connectionManager[x].Setup.GenerateSubscriptions) - if err != nil { + // Case if there is nothing to unsubscribe from and the connection is nil + if len(newSubs) == 0 && w.connectionManager[x].Connection == nil { + continue + } + + // If there are subscriptions to subscribe to but no connection to subscribe to, establish a new connection. + if w.connectionManager[x].Connection == nil { + conn := w.getConnectionFromSetup(w.connectionManager[x].Setup) + if err := w.connectionManager[x].Setup.Connector(context.TODO(), conn); err != nil { return err } + w.Wg.Add(1) + go w.Reader(context.TODO(), conn, w.connectionManager[x].Setup.Handler) + w.connections[conn] = &w.connectionManager[x] + w.connectionManager[x].Connection = conn } - return nil - } - if err := w.Shutdown(); err != nil { - return err - } - return w.Connect() -} + subs, unsubs := w.GetChannelDifference(w.connectionManager[x].Connection, newSubs) -func (w *Websocket) generateUnsubscribeAndSubscribe(conn Connection, generate func() (subscription.List, error)) error { - newSubs, err := generate() - if err != nil { - return err - } - subs, unsubs := w.GetChannelDifference(conn, newSubs) - if len(unsubs) != 0 && w.features.Unsubscribe { - if err := w.UnsubscribeChannels(conn, unsubs); err != nil { - return err + if len(unsubs) != 0 { + if err := w.UnsubscribeChannels(w.connectionManager[x].Connection, unsubs); err != nil { + return err + } + } + if len(subs) != 0 { + if err := w.SubscribeToChannels(w.connectionManager[x].Connection, subs); err != nil { + return err + } } - } - if len(subs) == 0 { - return nil - } - return w.SubscribeToChannels(conn, subs) -} -func (w *Websocket) generateAndSubscribe(store *subscription.Store, conn Connection, generate func() (subscription.List, error)) error { - newSubs, err := generate() - if err != nil { - return err - } - if len(newSubs) == 0 { - return nil + // If there are no subscriptions to subscribe to, close the connection as it is no longer needed. + if w.connectionManager[x].Subscriptions.Len() == 0 { + delete(w.connections, w.connectionManager[x].Connection) // Remove from lookup map + if err := w.connectionManager[x].Connection.Shutdown(); err != nil { + log.Warnf(log.WebsocketMgr, "%v websocket: failed to shutdown connection: %v", w.exchangeName, err) + } + w.connectionManager[x].Connection = nil + } } - store.Clear() // Purge subscription list as there will be conflicts - return w.SubscribeToChannels(conn, newSubs) + return nil } // trafficMonitor waits trafficCheckInterval before checking for a trafficAlert @@ -1008,8 +1013,8 @@ func (w *Websocket) GetName() string { // and the new subscription list when pairs are disabled or enabled. func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) { var subscriptionStore **subscription.Store - if candidate, ok := w.connections[conn]; ok { - subscriptionStore = &candidate.Subscriptions + if wrapper, ok := w.connections[conn]; ok { + subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions } @@ -1024,9 +1029,9 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L if len(channels) == 0 { return nil // No channels to unsubscribe from is not an error } - if candidate, ok := w.connections[conn]; ok { - return w.unsubscribe(candidate.Subscriptions, channels, func(channels subscription.List) error { - return candidate.Setup.Unsubscriber(context.TODO(), conn, channels) + if wrapper, ok := w.connections[conn]; ok { + return w.unsubscribe(wrapper.Subscriptions, channels, func(channels subscription.List) error { + return wrapper.Setup.Unsubscriber(context.TODO(), conn, channels) }) } return w.unsubscribe(w.subscriptions, channels, func(channels subscription.List) error { @@ -1070,8 +1075,8 @@ func (w *Websocket) SubscribeToChannels(conn Connection, subs subscription.List) return err } - if candidate, ok := w.connections[conn]; ok { - return candidate.Setup.Subscriber(context.TODO(), conn, subs) + if wrapper, ok := w.connections[conn]; ok { + return wrapper.Setup.Subscriber(context.TODO(), conn, subs) } if w.Subscriber == nil { @@ -1091,8 +1096,8 @@ func (w *Websocket) AddSubscriptions(conn Connection, subs ...*subscription.Subs return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer) } var subscriptionStore **subscription.Store - if candidate, ok := w.connections[conn]; ok { - subscriptionStore = &candidate.Subscriptions + if wrapper, ok := w.connections[conn]; ok { + subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions } @@ -1121,8 +1126,8 @@ func (w *Websocket) AddSuccessfulSubscriptions(conn Connection, subs ...*subscri } var subscriptionStore **subscription.Store - if candidate, ok := w.connections[conn]; ok { - subscriptionStore = &candidate.Subscriptions + if wrapper, ok := w.connections[conn]; ok { + subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions } @@ -1150,8 +1155,8 @@ func (w *Websocket) RemoveSubscriptions(conn Connection, subs ...*subscription.S } var subscriptionStore *subscription.Store - if candidate, ok := w.connections[conn]; ok { - subscriptionStore = candidate.Subscriptions + if wrapper, ok := w.connections[conn]; ok { + subscriptionStore = wrapper.Subscriptions } else { subscriptionStore = w.subscriptions } @@ -1237,8 +1242,8 @@ func checkWebsocketURL(s string) error { // The subscription state is not considered when counting existing subscriptions func (w *Websocket) checkSubscriptions(conn Connection, subs subscription.List) error { var subscriptionStore *subscription.Store - if candidate, ok := w.connections[conn]; ok { - subscriptionStore = candidate.Subscriptions + if wrapper, ok := w.connections[conn]; ok { + subscriptionStore = wrapper.Subscriptions } else { subscriptionStore = w.subscriptions } diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index ab66c11a7bc..cee93985dc8 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -143,13 +143,13 @@ func (w *WebsocketConnection) SetupPingHandler(epl request.EndpointLimit, handle ticker := time.NewTicker(handler.Delay) for { select { - case <-w.ShutdownC: + case <-w.shutdown: ticker.Stop() return case <-ticker.C: err := w.SendRawMessage(context.TODO(), epl, handler.MessageType, handler.Message) if err != nil { - log.Errorf(log.WebsocketMgr, "%v websocket connection: ping handler failed to send message [%s]", w.ExchangeName, handler.Message) + log.Errorf(log.WebsocketMgr, "%v websocket connection: ping handler failed to send message [%s]: %v", w.ExchangeName, handler.Message, err) return } } @@ -272,6 +272,7 @@ func (w *WebsocketConnection) Shutdown() error { return nil } w.setConnectedStatus(false) + close(w.shutdown) w.writeControl.Lock() defer w.writeControl.Unlock() return w.Connection.NetConn().Close() diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index c5c6dd190dd..ddd98e4ef83 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -7,7 +7,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "net/http" "net/http/httptest" "os" @@ -95,26 +94,10 @@ var defaultSetup = &WebsocketSetup{ Features: &protocol.Features{Subscribe: true, Unsubscribe: true}, } -type dodgyConnection struct { - WebsocketConnection -} - -// override websocket connection method to produce a wicked terrible error -func (d *dodgyConnection) Shutdown() error { - return fmt.Errorf("%w: %w", errCannotShutdown, errDastardlyReason) -} - -// override websocket connection method to produce a wicked terrible error -func (d *dodgyConnection) Connect() error { - return fmt.Errorf("cannot connect: %w", errDastardlyReason) -} - func TestMain(m *testing.M) { // Change trafficCheckInterval for TestTrafficMonitorTimeout before parallel tests to avoid racing trafficCheckInterval = 50 * time.Millisecond - r := m.Run() - - os.Exit(r) + os.Exit(m.Run()) } func TestSetup(t *testing.T) { @@ -967,7 +950,7 @@ func TestSetupPingHandler(t *testing.T) { if wc.ProxyURL != "" && !useProxyTests { t.Skip("Proxy testing not enabled, skipping") } - wc.ShutdownC = make(chan struct{}) + wc.shutdown = make(chan struct{}) err := wc.Dial(&dialer, http.Header{}) if err != nil { t.Fatal(err) @@ -994,7 +977,7 @@ func TestSetupPingHandler(t *testing.T) { Delay: 200, }) time.Sleep(time.Millisecond * 201) - close(wc.ShutdownC) + close(wc.shutdown) wc.Wg.Wait() } @@ -1061,7 +1044,7 @@ func TestGenerateMessageID(t *testing.T) { assert.EqualValues(t, 42, wc.GenerateMessageID(true), "GenerateMessageID must use bespokeGenerateMessageID") } -// BenchmarkGenerateMessageID-8 2850018 408 ns/op 56 B/op 4 allocs/op +// 7002502 166.7 ns/op 48 B/op 3 allocs/op func BenchmarkGenerateMessageID_High(b *testing.B) { wc := WebsocketConnection{} for i := 0; i < b.N; i++ { @@ -1069,7 +1052,7 @@ func BenchmarkGenerateMessageID_High(b *testing.B) { } } -// BenchmarkGenerateMessageID_Low-8 2591596 447 ns/op 56 B/op 4 allocs/op +// 6536250 186.1 ns/op 48 B/op 3 allocs/op func BenchmarkGenerateMessageID_Low(b *testing.B) { wc := WebsocketConnection{} for i := 0; i < b.N; i++ { @@ -1184,10 +1167,6 @@ func connect() error { return nil } func TestFlushChannels(t *testing.T) { t.Parallel() // Enabled pairs/setup system - newgen := GenSubs{EnabledPairs: []currency.Pair{ - currency.NewPair(currency.BTC, currency.AUD), - currency.NewPair(currency.BTC, currency.USDT), - }} dodgyWs := Websocket{} err := dodgyWs.FlushChannels() @@ -1197,6 +1176,11 @@ func TestFlushChannels(t *testing.T) { err = dodgyWs.FlushChannels() assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly") + newgen := GenSubs{EnabledPairs: []currency.Pair{ + currency.NewPair(currency.BTC, currency.AUD), + currency.NewPair(currency.BTC, currency.USDT), + }} + w := NewWebsocket() w.connector = connect w.Subscriber = newgen.SUBME @@ -1207,14 +1191,6 @@ func TestFlushChannels(t *testing.T) { w.setEnabled(true) w.setState(connectedState) - problemFunc := func() (subscription.List, error) { - return nil, errDastardlyReason - } - - noSub := func() (subscription.List, error) { - return nil, nil - } - // Disable pair and flush system newgen.EnabledPairs = []currency.Pair{ currency.NewPair(currency.BTC, currency.AUD)} @@ -1224,13 +1200,13 @@ func TestFlushChannels(t *testing.T) { err = w.FlushChannels() require.NoError(t, err, "Flush Channels must not error") - w.features.FullPayloadSubscribe = true - w.GenerateSubs = problemFunc - err = w.FlushChannels() // error on full subscribeToChannels + // w.features.FullPayloadSubscribe = true + w.GenerateSubs = func() (subscription.List, error) { return nil, errDastardlyReason } // error on generateSubs + err = w.FlushChannels() // error on full subscribeToChannels assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly on GenerateSubs") - w.GenerateSubs = noSub - err = w.FlushChannels() // No subs to sub + w.GenerateSubs = func() (subscription.List, error) { return nil, nil } // No subs to sub + err = w.FlushChannels() // No subs to sub assert.NoError(t, err, "Flush Channels should not error") w.GenerateSubs = newgen.generateSubs @@ -1239,7 +1215,6 @@ func TestFlushChannels(t *testing.T) { require.NoError(t, w.AddSubscriptions(nil, subs...), "AddSubscriptions must not error") err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - w.features.FullPayloadSubscribe = false w.features.Subscribe = true w.GenerateSubs = newgen.generateSubs @@ -1268,9 +1243,15 @@ func TestFlushChannels(t *testing.T) { // Multi connection management w.useMultiConnectionManagement = true w.exchangeName = "multi" + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + amazingCandidate := &ConnectionSetup{ - URL: "AMAZING", - Connector: func(context.Context, Connection) error { return nil }, + URL: "ws" + mock.URL[len("http"):] + "/ws", + Connector: func(ctx context.Context, conn Connection) error { + return conn.DialContext(ctx, websocket.DefaultDialer, nil) + }, GenerateSubscriptions: newgen.generateSubs, Subscriber: func(ctx context.Context, c Connection, s subscription.List) error { return currySimpleSubConn(w)(ctx, c, s) @@ -1283,8 +1264,14 @@ func TestFlushChannels(t *testing.T) { require.NoError(t, w.SetupNewConnection(amazingCandidate)) require.NoError(t, w.FlushChannels(), "FlushChannels must not error") + // Forces full connection cycle w.features.Subscribe = false - w.features.FullPayloadSubscribe = true + require.NoError(t, w.FlushChannels(), "FlushChannels must not error") + + // Unsubscribe whats already subscribed. No subscriptions left over, which then forces the shutdown and removal + // of the connection from management. + w.features.Subscribe = true + w.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } require.NoError(t, w.FlushChannels(), "FlushChannels must not error") } @@ -1380,12 +1367,12 @@ func TestSetupNewConnection(t *testing.T) { require.Nil(t, multi.Conn) err = multi.SetupNewConnection(connSetup) - require.ErrorIs(t, err, errConnectionCandidateDuplication) + require.ErrorIs(t, err, errConnectionWrapperDuplication) } func TestWebsocketConnectionShutdown(t *testing.T) { t.Parallel() - wc := WebsocketConnection{} + wc := WebsocketConnection{shutdown: make(chan struct{})} err := wc.Shutdown() assert.NoError(t, err, "Shutdown should not error") @@ -1512,79 +1499,3 @@ func TestWriteToConn(t *testing.T) { wc.RateLimit = nil require.ErrorIs(t, wc.writeToConn(ctx, request.Unset, func() error { return nil }), errRateLimitNotFound) } - -func TestGenerateUnsubscribeAndSubscribe(t *testing.T) { - t.Parallel() - ws := Websocket{subscriptions: subscription.NewStore(), features: &protocol.Features{}} - require.NoError(t, ws.subscriptions.Add(&subscription.Subscription{Channel: subscription.MyOrdersChannel})) - - generateError := errors.New("foo fighters the generator") - err := ws.generateUnsubscribeAndSubscribe(&WebsocketConnection{}, func() (subscription.List, error) { - return nil, generateError - }) - require.ErrorIs(t, err, generateError) - - err = ws.generateUnsubscribeAndSubscribe(&WebsocketConnection{}, func() (subscription.List, error) { - return subscription.List{{Channel: subscription.CandlesChannel}, {Channel: subscription.OrderbookChannel}}, nil - }) - require.ErrorIs(t, err, common.ErrNilPointer) - - failedSubscriberError := errors.New("failed subscriber") - ws.Subscriber = func(subscription.List) error { return failedSubscriberError } - err = ws.generateUnsubscribeAndSubscribe(&WebsocketConnection{}, func() (subscription.List, error) { - return subscription.List{{Channel: subscription.CandlesChannel}, {Channel: subscription.OrderbookChannel}}, nil - }) - require.ErrorIs(t, err, failedSubscriberError) - - failedUnSubscriberError := errors.New("failed unsubscriber") - ws.Subscriber = func(subscription.List) error { return nil } - ws.Unsubscriber = func(subscription.List) error { return failedUnSubscriberError } - ws.features.Unsubscribe = true - err = ws.generateUnsubscribeAndSubscribe(&WebsocketConnection{}, func() (subscription.List, error) { - return subscription.List{{Channel: subscription.CandlesChannel}, {Channel: subscription.OrderbookChannel}}, nil - }) - require.ErrorIs(t, err, failedUnSubscriberError) - - ws.Unsubscriber = func(subscription.List) error { return nil } - err = ws.generateUnsubscribeAndSubscribe(&WebsocketConnection{}, func() (subscription.List, error) { - return subscription.List{{Channel: subscription.CandlesChannel}, {Channel: subscription.OrderbookChannel}}, nil - }) - require.NoError(t, err) - - ws.Unsubscriber = func(subscription.List) error { return failedUnSubscriberError } - ws.Subscriber = func(subscription.List) error { return failedSubscriberError } - err = ws.generateUnsubscribeAndSubscribe(&WebsocketConnection{}, func() (subscription.List, error) { - return subscription.List{{Channel: subscription.MyOrdersChannel}}, nil - }) - require.NoError(t, err) -} - -func TestGenerateAndSubscribe(t *testing.T) { - t.Parallel() - - ws := Websocket{subscriptions: subscription.NewStore()} - - generateError := errors.New("foo fighters the generator") - err := ws.generateAndSubscribe(ws.subscriptions, &WebsocketConnection{}, func() (subscription.List, error) { - return nil, generateError - }) - require.ErrorIs(t, err, generateError) - - ws.Subscriber = func(subscription.List) error { return nil } - err = ws.generateAndSubscribe(ws.subscriptions, &WebsocketConnection{}, func() (subscription.List, error) { - return subscription.List{{Channel: subscription.CandlesChannel}, {Channel: subscription.OrderbookChannel}}, nil - }) - require.NoError(t, err) - - failedSubscriberError := errors.New("failed subscriber") - ws.Subscriber = func(subscription.List) error { return failedSubscriberError } - err = ws.generateAndSubscribe(ws.subscriptions, &WebsocketConnection{}, func() (subscription.List, error) { - return subscription.List{{Channel: subscription.CandlesChannel}, {Channel: subscription.OrderbookChannel}}, nil - }) - require.ErrorIs(t, err, failedSubscriberError) - - err = ws.generateAndSubscribe(ws.subscriptions, &WebsocketConnection{}, func() (subscription.List, error) { - return nil, nil - }) - require.NoError(t, err) -} diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index c753b89aea3..e903e997157 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -171,7 +171,9 @@ type WebsocketConnection struct { ProxyURL string Wg *sync.WaitGroup Connection *websocket.Conn - ShutdownC chan struct{} + + // shutdown synchronises shutdown event across routines associated with this connection only e.g. ping handler + shutdown chan struct{} Match *Match ResponseMaxLimit time.Duration From b877b385cb52948f3845dcf2cfd85b2b86ce3f65 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Mon, 16 Sep 2024 14:32:43 +1000 Subject: [PATCH 066/138] spella: fix --- exchanges/stream/websocket_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index ddd98e4ef83..3e63157400d 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1268,7 +1268,7 @@ func TestFlushChannels(t *testing.T) { w.features.Subscribe = false require.NoError(t, w.FlushChannels(), "FlushChannels must not error") - // Unsubscribe whats already subscribed. No subscriptions left over, which then forces the shutdown and removal + // Unsubscribe what's already subscribed. No subscriptions left over, which then forces the shutdown and removal // of the connection from management. w.features.Subscribe = true w.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } From f7d1ec827cee1129d4039e5673c04f1a04bd970b Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 19 Sep 2024 09:24:11 +1000 Subject: [PATCH 067/138] linter: fix --- exchanges/stream/websocket_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 3e63157400d..b8947267ac6 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1200,7 +1200,6 @@ func TestFlushChannels(t *testing.T) { err = w.FlushChannels() require.NoError(t, err, "Flush Channels must not error") - // w.features.FullPayloadSubscribe = true w.GenerateSubs = func() (subscription.List, error) { return nil, errDastardlyReason } // error on generateSubs err = w.FlushChannels() // error on full subscribeToChannels assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly on GenerateSubs") From bdc7afb27f6d9e100d90349c9489d5e27098cbad Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 19 Sep 2024 16:31:47 +1000 Subject: [PATCH 068/138] Add glorious suggestions, fix some race thing --- exchanges/stream/README.md | 19 +++++++++++++++---- exchanges/stream/websocket.go | 14 +++++++------- exchanges/stream/websocket_test.go | 2 -- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/exchanges/stream/README.md b/exchanges/stream/README.md index 65d16fc9fc7..3a02c2efcfd 100644 --- a/exchanges/stream/README.md +++ b/exchanges/stream/README.md @@ -4,15 +4,23 @@ This package is part of the GoCryptoTrader project and is responsible for handli ## Overview -The `stream` package provides functionalities to connect to various cryptocurrency exchanges and handle real-time data streams. +The `stream` package uses Gorilla Websocket and provides functionalities to connect to various cryptocurrency exchanges and handle real-time data streams. ## Features - Handle real-time market data streams - Unified interface for managing data streams +- Multi-connection management - a system that can be used to manage multiple connections to the same exchange +- Connection monitoring - a system that can be used to monitor the health of the websocket connections. This can be used to check if the connection is still alive and if it is not, it will attempt to reconnect +- Traffic monitoring - will reconnect if no message is sent for a period of time defined in your config +- Subscription management - a system that can be used to manage subscriptions to various data streams +- Rate limiting - a system that can be used to rate limit the number of requests sent to the exchange +- Message ID generation - a system that can be used to generate message IDs for websocket requests +- Websocket message response matching - can be used to match websocket responses to the requests that were sent ## Usage +### Default single websocket connection Here is a basic example of how to setup the `stream` package for websocket: ```go @@ -73,9 +81,12 @@ func (e *Exchange) Setup(exch *config.Exchange) error { RateLimit: request.NewRateLimitWithWeight(time.Second, 2, 1), }) } +``` -// The example below provides the now optional multi connection management system which allows for more connections -// to be maintained and established based off URL, connections types, asset types etc. +### Multiple websocket connections + The example below provides the now optional multi connection management system which allows for more connections + to be maintained and established based off URL, connections types, asset types etc. +```go func (e *Exchange) Setup(exch *config.Exchange) error { // This sets up global connection, sub, unsub and generate subscriptions for each connection defined below. if err := e.Websocket.Setup(&stream.WebsocketSetup{ @@ -123,4 +134,4 @@ func (e *Exchange) Setup(exch *config.Exchange) error { return err } } -``` +``` \ No newline at end of file diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 1803d67438a..6f1744c9bc2 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -1013,7 +1013,7 @@ func (w *Websocket) GetName() string { // and the new subscription list when pairs are disabled or enabled. func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) { var subscriptionStore **subscription.Store - if wrapper, ok := w.connections[conn]; ok { + if wrapper, ok := w.connections[conn]; ok && conn != nil { subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -1029,7 +1029,7 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L if len(channels) == 0 { return nil // No channels to unsubscribe from is not an error } - if wrapper, ok := w.connections[conn]; ok { + if wrapper, ok := w.connections[conn]; ok && conn != nil { return w.unsubscribe(wrapper.Subscriptions, channels, func(channels subscription.List) error { return wrapper.Setup.Unsubscriber(context.TODO(), conn, channels) }) @@ -1075,7 +1075,7 @@ func (w *Websocket) SubscribeToChannels(conn Connection, subs subscription.List) return err } - if wrapper, ok := w.connections[conn]; ok { + if wrapper, ok := w.connections[conn]; ok && conn != nil { return wrapper.Setup.Subscriber(context.TODO(), conn, subs) } @@ -1096,7 +1096,7 @@ func (w *Websocket) AddSubscriptions(conn Connection, subs ...*subscription.Subs return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer) } var subscriptionStore **subscription.Store - if wrapper, ok := w.connections[conn]; ok { + if wrapper, ok := w.connections[conn]; ok && conn != nil { subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -1126,7 +1126,7 @@ func (w *Websocket) AddSuccessfulSubscriptions(conn Connection, subs ...*subscri } var subscriptionStore **subscription.Store - if wrapper, ok := w.connections[conn]; ok { + if wrapper, ok := w.connections[conn]; ok && conn != nil { subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -1155,7 +1155,7 @@ func (w *Websocket) RemoveSubscriptions(conn Connection, subs ...*subscription.S } var subscriptionStore *subscription.Store - if wrapper, ok := w.connections[conn]; ok { + if wrapper, ok := w.connections[conn]; ok && conn != nil { subscriptionStore = wrapper.Subscriptions } else { subscriptionStore = w.subscriptions @@ -1242,7 +1242,7 @@ func checkWebsocketURL(s string) error { // The subscription state is not considered when counting existing subscriptions func (w *Websocket) checkSubscriptions(conn Connection, subs subscription.List) error { var subscriptionStore *subscription.Store - if wrapper, ok := w.connections[conn]; ok { + if wrapper, ok := w.connections[conn]; ok && conn != nil { subscriptionStore = wrapper.Subscriptions } else { subscriptionStore = w.subscriptions diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index b8947267ac6..d9e5b5d6a16 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1241,8 +1241,6 @@ func TestFlushChannels(t *testing.T) { // Multi connection management w.useMultiConnectionManagement = true - w.exchangeName = "multi" - mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() From 3d6541e4546f392b2110dd642e351782c1c4e17d Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 20 Sep 2024 08:52:55 +1000 Subject: [PATCH 069/138] reinstate name before any routine gets spawned --- exchanges/stream/websocket_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index d9e5b5d6a16..bef2466b81e 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1182,6 +1182,7 @@ func TestFlushChannels(t *testing.T) { }} w := NewWebsocket() + w.exchangeName = "test" w.connector = connect w.Subscriber = newgen.SUBME w.Unsubscriber = newgen.UNSUBME From 33c4128bc5ad380ae17e289a2926d1b4b510a770 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 20 Sep 2024 08:56:56 +1000 Subject: [PATCH 070/138] stop on error in mock tests --- internal/testing/websocket/mock.go | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/internal/testing/websocket/mock.go b/internal/testing/websocket/mock.go index bbccf4ab2cf..fb93d9becbb 100644 --- a/internal/testing/websocket/mock.go +++ b/internal/testing/websocket/mock.go @@ -2,7 +2,6 @@ package websocket import ( "net/http" - "strings" "testing" "github.com/gorilla/websocket" @@ -32,17 +31,10 @@ func WsMockUpgrader(tb testing.TB, w http.ResponseWriter, r *http.Request, wsHan defer c.Close() for { _, p, err := c.ReadMessage() - if websocket.IsUnexpectedCloseError(err) { + if err != nil { + // Any error here is likely due to the connection closing return } - - if err != nil && (strings.Contains(err.Error(), "wsarecv: An established connection was aborted by the software in your host machine.") || - strings.Contains(err.Error(), "wsarecv: An existing connection was forcibly closed by the remote host.")) { - return - } - - require.NoError(tb, err, "ReadMessage should not error") - err = wsHandler(p, c) assert.NoError(tb, err, "WS Mock Function should not error") } From ea257634939a8fb48c8c7265135d67c073dbd861 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 20 Sep 2024 09:20:55 +1000 Subject: [PATCH 071/138] glorious: nits --- exchanges/gateio/gateio_websocket.go | 15 ++++------ .../gateio/gateio_ws_delivery_futures.go | 29 ++++++++----------- exchanges/gateio/gateio_ws_futures.go | 22 +++++++------- exchanges/gateio/gateio_ws_option.go | 23 ++++++++------- 4 files changed, 42 insertions(+), 47 deletions(-) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 27d51a1dcd6..ea4e473921a 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -631,20 +631,17 @@ func (g *Gateio) GenerateDefaultSubscriptionsSpot() (subscription.List, error) { switch channelsToSubscribe[i] { case marginBalancesChannel: assetType = asset.Margin - if pairs, err = g.GetEnabledPairs(asset.Margin); err != nil && !errors.Is(err, asset.ErrNotEnabled) { - return nil, err - } + pairs, err = g.GetEnabledPairs(asset.Margin) case crossMarginBalanceChannel: assetType = asset.CrossMargin - if pairs, err = g.GetEnabledPairs(asset.CrossMargin); err != nil && !errors.Is(err, asset.ErrNotEnabled) { - return nil, err - } + pairs, err = g.GetEnabledPairs(asset.CrossMargin) default: // TODO: Check and add balance support as spot balances can be subscribed without a currency pair supplied. assetType = asset.Spot - if pairs, err = g.GetEnabledPairs(asset.Spot); err != nil && !errors.Is(err, asset.ErrNotEnabled) { - return nil, err - } + pairs, err = g.GetEnabledPairs(asset.Spot) + } + if err != nil && !errors.Is(err, asset.ErrNotEnabled) { + return nil, err } for j := range pairs { diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index 3109eca9a38..02225c34ce0 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -72,26 +72,21 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.Lis } channelsToSubscribe := defaultDeliveryFuturesSubscriptions if g.Websocket.CanUseAuthenticatedEndpoints() { - channelsToSubscribe = append( - channelsToSubscribe, - futuresOrdersChannel, - futuresUserTradesChannel, - futuresBalancesChannel, - ) + channelsToSubscribe = append(channelsToSubscribe, futuresOrdersChannel, futuresUserTradesChannel, futuresBalancesChannel) } - var subscriptions subscription.List - for i := range channelsToSubscribe { - pairs, err := g.GetEnabledPairs(asset.DeliveryFutures) - if err != nil { - if errors.Is(err, asset.ErrNotEnabled) { - continue // skip if not enabled - } - return nil, err + pairs, err := g.GetEnabledPairs(asset.DeliveryFutures) + if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + return nil, nil // no enabled pairs, subscriptions require an associated pair. } + return nil, err + } + var subscriptions subscription.List + for i := range channelsToSubscribe { for j := range pairs { - params := make(map[string]interface{}) + params := make(map[string]any) switch channelsToSubscribe[i] { case futuresOrderbookChannel: params["limit"] = 20 @@ -99,13 +94,13 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.Lis case futuresCandlesticksChannel: params["interval"] = kline.FiveMin } - fpair, err := g.FormatExchangeCurrency(pairs[j], asset.DeliveryFutures) + fPair, err := g.FormatExchangeCurrency(pairs[j], asset.DeliveryFutures) if err != nil { return nil, err } subscriptions = append(subscriptions, &subscription.Subscription{ Channel: channelsToSubscribe[i], - Pairs: currency.Pairs{fpair.Upper()}, + Pairs: currency.Pairs{fPair.Upper()}, Params: params, }) } diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 58417e37e9b..820db1d1682 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -95,16 +95,16 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions(settlement currency.Code) ( ) } - var subscriptions subscription.List - for i := range channelsToSubscribe { - pairs, err := g.GetEnabledPairs(asset.Futures) - if err != nil { - if errors.Is(err, asset.ErrNotEnabled) { - continue // skip if not enabled - } - return nil, err + pairs, err := g.GetEnabledPairs(asset.Futures) + if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + return nil, nil // no enabled pairs, subscriptions require an associated pair. } + return nil, err + } + var subscriptions subscription.List + for i := range channelsToSubscribe { switch { case settlement.Equal(currency.USDT): pairs, err = pairs.GetPairsByQuote(currency.USDT) @@ -126,7 +126,7 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions(settlement currency.Code) ( } for j := range pairs { - params := make(map[string]interface{}) + params := make(map[string]any) switch channelsToSubscribe[i] { case futuresOrderbookChannel: params["limit"] = 100 @@ -137,13 +137,13 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions(settlement currency.Code) ( params["frequency"] = kline.ThousandMilliseconds params["level"] = "100" } - fpair, err := g.FormatExchangeCurrency(pairs[j], asset.Futures) + fPair, err := g.FormatExchangeCurrency(pairs[j], asset.Futures) if err != nil { return nil, err } subscriptions = append(subscriptions, &subscription.Subscription{ Channel: channelsToSubscribe[i], - Pairs: currency.Pairs{fpair.Upper()}, + Pairs: currency.Pairs{fPair.Upper()}, Params: params, }) } diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index 00d376a11e4..50b43612018 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -119,18 +119,21 @@ func (g *Gateio) GenerateOptionsDefaultSubscriptions() (subscription.List, error log.Errorf(log.ExchangeSys, "no subaccount found for authenticated options channel subscriptions") } } + getEnabledPairs: + + pairs, err := g.GetEnabledPairs(asset.Options) + if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + return nil, nil // no enabled pairs, subscriptions require an associated pair. + } + return nil, err + } + var subscriptions subscription.List for i := range channelsToSubscribe { - pairs, err := g.GetEnabledPairs(asset.Options) - if err != nil { - if errors.Is(err, asset.ErrNotEnabled) { - continue // skip if not enabled - } - return nil, err - } for j := range pairs { - params := make(map[string]interface{}) + params := make(map[string]any) switch channelsToSubscribe[i] { case optionsOrderbookChannel: params["accuracy"] = "0" @@ -152,13 +155,13 @@ getEnabledPairs: } params["user_id"] = userID } - fpair, err := g.FormatExchangeCurrency(pairs[j], asset.Options) + fPair, err := g.FormatExchangeCurrency(pairs[j], asset.Options) if err != nil { return nil, err } subscriptions = append(subscriptions, &subscription.Subscription{ Channel: channelsToSubscribe[i], - Pairs: currency.Pairs{fpair.Upper()}, + Pairs: currency.Pairs{fPair.Upper()}, Params: params, }) } From ae92cd2b00006389049c799c1376fc4a40e5d3ac Mon Sep 17 00:00:00 2001 From: shazbert Date: Tue, 24 Sep 2024 14:40:22 +1000 Subject: [PATCH 072/138] glorious: nits found in CI build --- exchanges/stream/websocket.go | 28 +++++++++++++--------------- exchanges/stream/websocket_test.go | 26 ++++++++++++-------------- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 6f1744c9bc2..457c65b37ec 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -789,22 +789,11 @@ func (w *Websocket) trafficMonitor() { w.Wg.Done() return case <-time.After(trafficCheckInterval): - select { - case <-w.TrafficAlert: - if !t.Stop() { - <-t.C - } + if signalReceived(w.TrafficAlert) { t.Reset(w.trafficTimeout) - default: } case <-t.C: - checkAgain := w.IsConnecting() - select { - case <-w.TrafficAlert: - checkAgain = true - default: - } - if checkAgain { + if w.IsConnecting() || signalReceived(w.TrafficAlert) { t.Reset(w.trafficTimeout) break } @@ -814,8 +803,7 @@ func (w *Websocket) trafficMonitor() { w.setTrafficMonitorRunning(false) // Cannot defer lest Connect is called after Shutdown but before deferred call w.Wg.Done() // Without this the w.Shutdown() call below will deadlock if w.IsConnected() { - err := w.Shutdown() - if err != nil { + if err := w.Shutdown(); err != nil { log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) } } @@ -825,6 +813,16 @@ func (w *Websocket) trafficMonitor() { }() } +// signalReceived checks if a signal has been received, this also clears the signal. +func signalReceived(ch chan struct{}) bool { + select { + case <-ch: + return true + default: + return false + } +} + func (w *Websocket) setState(s uint32) { w.state.Store(s) } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index bef2466b81e..56c6d5aab87 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -250,17 +250,16 @@ func TestTrafficMonitorShutdown(t *testing.T) { ws.state.Store(connectedState) ws.trafficTimeout = time.Minute + ws.verbose = true + close(ws.TrafficAlert) // Mocks channel traffic signal. ws.trafficMonitor() assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") wgReady := make(chan bool) - go func() { - ws.Wg.Wait() - close(wgReady) - }() + go func() { ws.Wg.Wait(); close(wgReady) }() select { case <-wgReady: - require.Failf(t, "", "WaitGroup should be blocking still") + require.Fail(t, "WaitGroup should be blocking still") case <-time.After(trafficCheckInterval): } @@ -271,7 +270,7 @@ func TestTrafficMonitorShutdown(t *testing.T) { select { case <-wgReady: default: - require.Failf(t, "", "WaitGroup should be freed now") + require.Fail(t, "WaitGroup should be freed now") } } @@ -1192,12 +1191,13 @@ func TestFlushChannels(t *testing.T) { w.setEnabled(true) w.setState(connectedState) + // Allow subscribe and unsubscribe feature set, without these the tests will call shutdown and connect. + w.features.Subscribe = true + w.features.Unsubscribe = true + // Disable pair and flush system - newgen.EnabledPairs = []currency.Pair{ - currency.NewPair(currency.BTC, currency.AUD)} - w.GenerateSubs = func() (subscription.List, error) { - return subscription.List{{Channel: "test"}}, nil - } + newgen.EnabledPairs = []currency.Pair{currency.NewPair(currency.BTC, currency.AUD)} + w.GenerateSubs = func() (subscription.List, error) { return subscription.List{{Channel: "test"}}, nil } err = w.FlushChannels() require.NoError(t, err, "Flush Channels must not error") @@ -1215,7 +1215,6 @@ func TestFlushChannels(t *testing.T) { require.NoError(t, w.AddSubscriptions(nil, subs...), "AddSubscriptions must not error") err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - w.features.Subscribe = true w.GenerateSubs = newgen.generateSubs w.subscriptions = subscription.NewStore() @@ -1236,7 +1235,6 @@ func TestFlushChannels(t *testing.T) { assert.NoError(t, err, "FlushChannels should not error") w.setState(connectedState) - w.features.Unsubscribe = true err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") @@ -1262,7 +1260,7 @@ func TestFlushChannels(t *testing.T) { require.NoError(t, w.SetupNewConnection(amazingCandidate)) require.NoError(t, w.FlushChannels(), "FlushChannels must not error") - // Forces full connection cycle + // Forces full connection cycle (shutdown, connect, subscribe). This will also start monitoring routines. w.features.Subscribe = false require.NoError(t, w.FlushChannels(), "FlushChannels must not error") From c49a75c745d567896f6a337273cf042f29cf3201 Mon Sep 17 00:00:00 2001 From: shazbert Date: Tue, 24 Sep 2024 15:09:13 +1000 Subject: [PATCH 073/138] Add test for drain, bumped wait times as there seems to be something happening on macos CI builds, used context.WithTimeout because its instant. --- exchanges/stream/websocket_test.go | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 56c6d5aab87..42b2ee7d92a 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -238,7 +238,7 @@ func TestTrafficMonitorConnecting(t *testing.T) { require.EventuallyWithT(t, func(c *assert.CollectT) { assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected") assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") - }, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") + }, 8*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") } // TestTrafficMonitorShutdown ensures shutdown is processed and waitgroup is cleared @@ -265,7 +265,7 @@ func TestTrafficMonitorShutdown(t *testing.T) { close(ws.ShutdownC) - <-time.After(2 * trafficCheckInterval) + <-time.After(4 * trafficCheckInterval) assert.False(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be shutdown") select { case <-wgReady: @@ -1480,10 +1480,9 @@ func TestWriteToConn(t *testing.T) { // connection rate limit set wc.RateLimit = request.NewWeightedRateLimitByDuration(time.Millisecond) require.NoError(t, wc.writeToConn(context.Background(), request.Unset, func() error { return nil })) - // context cancelled - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 0) // deadline exceeded cancel() - require.ErrorIs(t, wc.writeToConn(ctx, request.Unset, func() error { return nil }), context.Canceled) + require.ErrorIs(t, wc.writeToConn(ctx, request.Unset, func() error { return nil }), context.DeadlineExceeded) // definitions set but with fallover wc.RateLimitDefinitions = request.RateLimitDefinitions{ request.Auth: request.NewWeightedRateLimitByDuration(time.Millisecond), @@ -1495,3 +1494,17 @@ func TestWriteToConn(t *testing.T) { wc.RateLimit = nil require.ErrorIs(t, wc.writeToConn(ctx, request.Unset, func() error { return nil }), errRateLimitNotFound) } + +func TestDrain(t *testing.T) { + t.Parallel() + drain(nil) + ch := make(chan error) + drain(ch) + require.Empty(t, ch, "Drain should empty the channel") + ch = make(chan error, 10) + for i := 0; i < 10; i++ { + ch <- errors.New("test") + } + drain(ch) + require.Empty(t, ch, "Drain should empty the channel") +} From 5551d547703e218003bb088f845ad16907825209 Mon Sep 17 00:00:00 2001 From: shazbert Date: Tue, 24 Sep 2024 15:21:44 +1000 Subject: [PATCH 074/138] mutex across shutdown and connect for protection --- exchanges/stream/websocket.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 457c65b37ec..c968e891684 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -318,7 +318,10 @@ func (w *Websocket) getConnectionFromSetup(c *ConnectionSetup) *WebsocketConnect func (w *Websocket) Connect() error { w.m.Lock() defer w.m.Unlock() + return w.connect() +} +func (w *Websocket) connect() error { if !w.IsEnabled() { return ErrWebsocketNotEnabled } @@ -613,7 +616,10 @@ func (w *Websocket) connectionMonitor() error { func (w *Websocket) Shutdown() error { w.m.Lock() defer w.m.Unlock() + return w.shutdown() +} +func (w *Websocket) shutdown() error { if !w.IsConnected() { return fmt.Errorf("%v %w: %w", w.exchangeName, errCannotShutdown, ErrNotConnected) } @@ -697,10 +703,12 @@ func (w *Websocket) FlushChannels() error { // If the exchange does not support subscribing and or unsubscribing the full connection needs to be flushed to // maintain consistency. if !w.features.Subscribe || !w.features.Unsubscribe { - if err := w.Shutdown(); err != nil { + w.m.Lock() + defer w.m.Unlock() + if err := w.shutdown(); err != nil { return err } - return w.Connect() + return w.connect() } if !w.useMultiConnectionManagement { From 99c3c7b28de9ad46628af04d6ca6541a79bc0131 Mon Sep 17 00:00:00 2001 From: shazbert Date: Tue, 24 Sep 2024 15:26:46 +1000 Subject: [PATCH 075/138] lint: fix --- exchanges/stream/websocket_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 42b2ee7d92a..be339154da1 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1502,7 +1502,7 @@ func TestDrain(t *testing.T) { drain(ch) require.Empty(t, ch, "Drain should empty the channel") ch = make(chan error, 10) - for i := 0; i < 10; i++ { + for range 10 { ch <- errors.New("test") } drain(ch) From f56cb78815e9baecab692c1137554f4b4f6a21cc Mon Sep 17 00:00:00 2001 From: shazbert Date: Tue, 24 Sep 2024 16:14:44 +1000 Subject: [PATCH 076/138] test time withoffset, reinstate stop --- exchanges/stream/websocket.go | 1 + exchanges/stream/websocket_test.go | 10 +++------- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index c968e891684..091fb2bc66e 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -798,6 +798,7 @@ func (w *Websocket) trafficMonitor() { return case <-time.After(trafficCheckInterval): if signalReceived(w.TrafficAlert) { + t.Stop() t.Reset(w.trafficTimeout) } case <-t.C: diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index be339154da1..96de564178e 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -194,9 +194,7 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { for i := range 6 { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass select { case ws.TrafficAlert <- signal: - if i == 0 { - require.WithinDurationf(t, time.Now(), thenish, trafficCheckInterval, "First Non-blocking test must happen before the traffic is checked") - } + require.WithinDuration(t, thenish.Add(time.Duration(i)*trafficCheckInterval), time.Now(), trafficCheckInterval) default: require.Failf(t, "", "TrafficAlert should not block; Check #%d", i) } @@ -205,12 +203,10 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { case ws.TrafficAlert <- signal: require.Failf(t, "", "TrafficAlert should block after first slot used; Check #%d", i) default: - if i == 0 { - require.WithinDuration(t, time.Now(), thenish, trafficCheckInterval, "First Blocking test must happen before the traffic is checked") - } + require.WithinDuration(t, time.Now(), thenish.Add(time.Duration(i)*trafficCheckInterval), trafficCheckInterval) } - require.Eventuallyf(t, func() bool { return len(ws.TrafficAlert) == 0 }, 5*time.Second, patience, "trafficAlert should be drained; Check #%d", i) + require.Eventuallyf(t, func() bool { return len(ws.TrafficAlert) == 0 }, 2*trafficCheckInterval, patience, "trafficAlert should be drained; Check #%d", i) assert.Truef(t, ws.IsConnected(), "state should still be connected; Check #%d", i) } From b4bc7e2ecb0c3fae554c45e065b0fca55e772c53 Mon Sep 17 00:00:00 2001 From: shazbert Date: Tue, 24 Sep 2024 16:25:46 +1000 Subject: [PATCH 077/138] fix whoops --- exchanges/stream/websocket_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 96de564178e..26d5fdd1c9b 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -194,7 +194,7 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { for i := range 6 { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass select { case ws.TrafficAlert <- signal: - require.WithinDuration(t, thenish.Add(time.Duration(i)*trafficCheckInterval), time.Now(), trafficCheckInterval) + require.WithinDuration(t, time.Now(), thenish.Add(time.Duration(i)*trafficCheckInterval), trafficCheckInterval) default: require.Failf(t, "", "TrafficAlert should not block; Check #%d", i) } From ce4a5ce1768e8d138e0134301bfb415d0c9cf505 Mon Sep 17 00:00:00 2001 From: shazbert Date: Tue, 24 Sep 2024 16:44:24 +1000 Subject: [PATCH 078/138] const trafficCheckInterval; rm testmain --- exchanges/stream/websocket.go | 7 +++---- exchanges/stream/websocket_test.go | 7 ------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 091fb2bc66e..c8c7424dad2 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -69,10 +69,9 @@ var ( errExchangeConfigEmpty = errors.New("exchange config is empty") ) -var ( - globalReporter Reporter - trafficCheckInterval = 100 * time.Millisecond -) +var globalReporter Reporter + +const trafficCheckInterval = 100 * time.Millisecond // SetupGlobalReporter sets a reporter interface to be used // for all exchange requests diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 26d5fdd1c9b..0f173ca3255 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -9,7 +9,6 @@ import ( "errors" "net/http" "net/http/httptest" - "os" "strconv" "strings" "sync" @@ -94,12 +93,6 @@ var defaultSetup = &WebsocketSetup{ Features: &protocol.Features{Subscribe: true, Unsubscribe: true}, } -func TestMain(m *testing.M) { - // Change trafficCheckInterval for TestTrafficMonitorTimeout before parallel tests to avoid racing - trafficCheckInterval = 50 * time.Millisecond - os.Exit(m.Run()) -} - func TestSetup(t *testing.T) { t.Parallel() var w *Websocket From a745992687adcb58bd7e43ea153ce80a9255e2b4 Mon Sep 17 00:00:00 2001 From: shazbert Date: Tue, 24 Sep 2024 17:06:13 +1000 Subject: [PATCH 079/138] y --- exchanges/stream/websocket.go | 3 +-- exchanges/stream/websocket_test.go | 10 ++++------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index c8c7424dad2..87b2ab9a9bd 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -797,7 +797,6 @@ func (w *Websocket) trafficMonitor() { return case <-time.After(trafficCheckInterval): if signalReceived(w.TrafficAlert) { - t.Stop() t.Reset(w.trafficTimeout) } case <-t.C: @@ -824,7 +823,7 @@ func (w *Websocket) trafficMonitor() { // signalReceived checks if a signal has been received, this also clears the signal. func signalReceived(ch chan struct{}) bool { select { - case <-ch: + case _ = <-ch: return true default: return false diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 0f173ca3255..10d120f19cf 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -173,8 +173,6 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { err := ws.Setup(defaultSetup) require.NoError(t, err, "Setup must not error") - signal := struct{}{} - patience := 10 * time.Millisecond ws.trafficTimeout = 200 * time.Millisecond ws.state.Store(connectedState) @@ -186,27 +184,27 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { for i := range 6 { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass select { - case ws.TrafficAlert <- signal: + case ws.TrafficAlert <- struct{}{}: require.WithinDuration(t, time.Now(), thenish.Add(time.Duration(i)*trafficCheckInterval), trafficCheckInterval) default: require.Failf(t, "", "TrafficAlert should not block; Check #%d", i) } select { - case ws.TrafficAlert <- signal: + case ws.TrafficAlert <- struct{}{}: require.Failf(t, "", "TrafficAlert should block after first slot used; Check #%d", i) default: require.WithinDuration(t, time.Now(), thenish.Add(time.Duration(i)*trafficCheckInterval), trafficCheckInterval) } - require.Eventuallyf(t, func() bool { return len(ws.TrafficAlert) == 0 }, 2*trafficCheckInterval, patience, "trafficAlert should be drained; Check #%d", i) + require.Eventuallyf(t, func() bool { return len(ws.TrafficAlert) == 0 }, 2*trafficCheckInterval, time.Millisecond, "trafficAlert should be drained; Check #%d", i) assert.Truef(t, ws.IsConnected(), "state should still be connected; Check #%d", i) } require.EventuallyWithT(t, func(c *assert.CollectT) { assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected") assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") - }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts") + }, 2*ws.trafficTimeout, time.Millisecond, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts") } // TestTrafficMonitorConnecting ensures connecting status doesn't trigger shutdown From 086fcc3349f2ebf2a4bf672f2f737c9f0a104e00 Mon Sep 17 00:00:00 2001 From: shazbert Date: Tue, 24 Sep 2024 17:11:35 +1000 Subject: [PATCH 080/138] fix lint --- exchanges/stream/websocket.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 87b2ab9a9bd..2319b7153e7 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -823,7 +823,7 @@ func (w *Websocket) trafficMonitor() { // signalReceived checks if a signal has been received, this also clears the signal. func signalReceived(ch chan struct{}) bool { select { - case _ = <-ch: + case <-ch: return true default: return false From a64e84224e6e4d2ec47a17d019f642b2395a0571 Mon Sep 17 00:00:00 2001 From: shazbert Date: Tue, 24 Sep 2024 17:22:55 +1000 Subject: [PATCH 081/138] bump time check window --- exchanges/stream/websocket_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 10d120f19cf..2280716865b 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -185,7 +185,7 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { for i := range 6 { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass select { case ws.TrafficAlert <- struct{}{}: - require.WithinDuration(t, time.Now(), thenish.Add(time.Duration(i)*trafficCheckInterval), trafficCheckInterval) + require.WithinDuration(t, time.Now(), thenish.Add(time.Duration(i)*trafficCheckInterval), trafficCheckInterval*2) default: require.Failf(t, "", "TrafficAlert should not block; Check #%d", i) } @@ -194,7 +194,7 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { case ws.TrafficAlert <- struct{}{}: require.Failf(t, "", "TrafficAlert should block after first slot used; Check #%d", i) default: - require.WithinDuration(t, time.Now(), thenish.Add(time.Duration(i)*trafficCheckInterval), trafficCheckInterval) + require.WithinDuration(t, time.Now(), thenish.Add(time.Duration(i)*trafficCheckInterval), trafficCheckInterval*2) } require.Eventuallyf(t, func() bool { return len(ws.TrafficAlert) == 0 }, 2*trafficCheckInterval, time.Millisecond, "trafficAlert should be drained; Check #%d", i) From 9e53313fbe0b8a5485fc2a47e37d97b9fbcf3cc2 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 25 Sep 2024 12:14:46 +1000 Subject: [PATCH 082/138] stream: fix intermittant test failures while testing routines and remove code that is not needed. --- exchanges/stream/websocket.go | 366 ++++++++++++---------------- exchanges/stream/websocket_test.go | 222 ++++++++--------- exchanges/stream/websocket_types.go | 2 - 3 files changed, 257 insertions(+), 333 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 2319b7153e7..388ad94cefb 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -6,6 +6,7 @@ import ( "fmt" "net/url" "slices" + "sync" "time" "github.com/thrasher-corp/gocryptotrader/common" @@ -71,8 +72,6 @@ var ( var globalReporter Reporter -const trafficCheckInterval = 100 * time.Millisecond - // SetupGlobalReporter sets a reporter interface to be used // for all exchange requests func SetupGlobalReporter(r Reporter) { @@ -336,10 +335,17 @@ func (w *Websocket) connect() error { } w.subscriptions.Clear() - w.dataMonitor() - w.trafficMonitor() w.setState(connectingState) + w.Wg.Add(2) + go w.monitorFrame(&w.Wg, w.monitorData) + go w.monitorFrame(&w.Wg, w.monitorTraffic) + + if w.connectionMonitorRunning.CompareAndSwap(false, true) { + // This oversees all connections and does not need to be part of wait group management. + go w.monitorFrame(nil, w.monitorConnection) + } + if !w.useMultiConnectionManagement { if w.connector == nil { return fmt.Errorf("%v %w", w.exchangeName, errNoConnectFunc) @@ -351,13 +357,6 @@ func (w *Websocket) connect() error { } w.setState(connectedState) - if !w.IsConnectionMonitorRunning() { - err = w.connectionMonitor() - if err != nil { - log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err) - } - } - subs, err := w.GenerateSubs() // regenerate state on new connection if err != nil { return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) @@ -475,14 +474,6 @@ func (w *Websocket) connect() error { // All subscriptions have been sent and stored. All data received is being // handled by the appropriate data handler. w.setState(connectedState) - - if !w.IsConnectionMonitorRunning() { - err := w.connectionMonitor() - if err != nil { - log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err) - } - } - return nil } @@ -507,109 +498,6 @@ func (w *Websocket) Enable() error { return w.Connect() } -// dataMonitor monitors job throughput and logs if there is a back log of data -func (w *Websocket) dataMonitor() { - if w.IsDataMonitorRunning() { - return - } - w.setDataMonitorRunning(true) - w.Wg.Add(1) - - go func() { - defer func() { - w.setDataMonitorRunning(false) - w.Wg.Done() - }() - dropped := 0 - for { - select { - case <-w.ShutdownC: - return - case d := <-w.DataHandler: - select { - case w.ToRoutine <- d: - if dropped != 0 { - log.Infof(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer recovered; %d messages were dropped", w.exchangeName, dropped) - dropped = 0 - } - default: - if dropped == 0 { - // If this becomes prone to flapping we could drain the buffer, but that's extreme and we'd like to avoid it if possible - log.Warnf(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer full; dropping messages", w.exchangeName) - } - dropped++ - } - } - } - }() -} - -// connectionMonitor ensures that the WS keeps connecting -func (w *Websocket) connectionMonitor() error { - if w.checkAndSetMonitorRunning() { - return errAlreadyRunning - } - delay := w.connectionMonitorDelay - - go func() { - timer := time.NewTimer(delay) - for { - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", w.exchangeName) - } - if !w.IsEnabled() { - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", w.exchangeName) - } - if w.IsConnected() { - if err := w.Shutdown(); err != nil { - log.Errorln(log.WebsocketMgr, err) - } - } - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", w.exchangeName) - } - timer.Stop() - w.setConnectionMonitorRunning(false) - return - } - select { - case err := <-w.ReadMessageErrors: - if errors.Is(err, errConnectionFault) { - log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) - if w.IsConnected() { - if shutdownErr := w.Shutdown(); shutdownErr != nil { - log.Errorf(log.WebsocketMgr, "%v websocket: connectionMonitor shutdown err: %s", w.exchangeName, shutdownErr) - } - } - } - // Speedier reconnection, instead of waiting for the next cycle. - if w.IsEnabled() && (!w.IsConnected() && !w.IsConnecting()) { - if connectErr := w.Connect(); connectErr != nil { - log.Errorln(log.WebsocketMgr, connectErr) - } - } - w.DataHandler <- err // hand over the error to the data handler (shutdown and reconnection is priority) - case <-timer.C: - if !w.IsConnecting() && !w.IsConnected() { - err := w.Connect() - if err != nil { - log.Errorln(log.WebsocketMgr, err) - } - } - if !timer.Stop() { - select { - case <-timer.C: - default: - } - } - timer.Reset(delay) - } - } - }() - return nil -} - // Shutdown attempts to shut down a websocket connection and associated routines // by using a package defined shutdown function func (w *Websocket) Shutdown() error { @@ -773,63 +661,6 @@ func (w *Websocket) FlushChannels() error { return nil } -// trafficMonitor waits trafficCheckInterval before checking for a trafficAlert -// 1 slot buffer means that connection will only write to trafficAlert once per trafficCheckInterval to avoid read/write flood in high traffic -// Otherwise we Shutdown the connection after trafficTimeout, unless it's connecting. connectionMonitor is responsible for Connecting again -func (w *Websocket) trafficMonitor() { - if w.IsTrafficMonitorRunning() { - return - } - w.setTrafficMonitorRunning(true) - w.Wg.Add(1) - - go func() { - t := time.NewTimer(w.trafficTimeout) - for { - select { - case <-w.ShutdownC: - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName) - } - t.Stop() - w.setTrafficMonitorRunning(false) - w.Wg.Done() - return - case <-time.After(trafficCheckInterval): - if signalReceived(w.TrafficAlert) { - t.Reset(w.trafficTimeout) - } - case <-t.C: - if w.IsConnecting() || signalReceived(w.TrafficAlert) { - t.Reset(w.trafficTimeout) - break - } - if w.verbose { - log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout) - } - w.setTrafficMonitorRunning(false) // Cannot defer lest Connect is called after Shutdown but before deferred call - w.Wg.Done() // Without this the w.Shutdown() call below will deadlock - if w.IsConnected() { - if err := w.Shutdown(); err != nil { - log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) - } - } - return - } - } - }() -} - -// signalReceived checks if a signal has been received, this also clears the signal. -func signalReceived(ch chan struct{}) bool { - select { - case <-ch: - return true - default: - return false - } -} - func (w *Websocket) setState(s uint32) { w.state.Store(s) } @@ -858,37 +689,6 @@ func (w *Websocket) IsEnabled() bool { return w.enabled.Load() } -func (w *Websocket) setTrafficMonitorRunning(b bool) { - w.trafficMonitorRunning.Store(b) -} - -// IsTrafficMonitorRunning returns status of the traffic monitor -func (w *Websocket) IsTrafficMonitorRunning() bool { - return w.trafficMonitorRunning.Load() -} - -func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) { - return !w.connectionMonitorRunning.CompareAndSwap(false, true) -} - -func (w *Websocket) setConnectionMonitorRunning(b bool) { - w.connectionMonitorRunning.Store(b) -} - -// IsConnectionMonitorRunning returns status of connection monitor -func (w *Websocket) IsConnectionMonitorRunning() bool { - return w.connectionMonitorRunning.Load() -} - -func (w *Websocket) setDataMonitorRunning(b bool) { - w.dataMonitorRunning.Store(b) -} - -// IsDataMonitorRunning returns status of data monitor -func (w *Websocket) IsDataMonitorRunning() bool { - return w.dataMonitorRunning.Load() -} - // CanUseAuthenticatedWebsocketForWrapper Handles a common check to // verify whether a wrapper can use an authenticated websocket endpoint func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool { @@ -1297,3 +1097,149 @@ func drain(ch <-chan error) { } } } + +// ClosureFrame is a closure function that wraps monitoring variables with observer, if the return is true the frame will exit +type ClosureFrame func() func() bool + +// monitorFrame monitors a specific websocket componant or critical system. It will exit if the observer returns true +// This is used for monitoring data throughput, connection status and other critical websocket components. The waitgroup +// is optional and is used to signal when the monitor has finished. +func (w *Websocket) monitorFrame(wg *sync.WaitGroup, fn ClosureFrame) { + if wg != nil { + defer w.Wg.Done() + } + observe := fn() + for { + if observe() { + return + } + } +} + +// monitorData monitors data throughput and logs if there is a back log of data +func (w *Websocket) monitorData() func() bool { + dropped := 0 + return func() bool { return w.observeData(&dropped) } +} + +// observeData observes data throughput and logs if there is a back log of data +func (w *Websocket) observeData(dropped *int) (exit bool) { + select { + case <-w.ShutdownC: + return true + case d := <-w.DataHandler: + select { + case w.ToRoutine <- d: + if *dropped != 0 { + log.Infof(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer recovered; %d messages were dropped", w.exchangeName, dropped) + *dropped = 0 + } + default: + if *dropped == 0 { + // If this becomes prone to flapping we could drain the buffer, but that's extreme and we'd like to avoid it if possible + log.Warnf(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer full; dropping messages", w.exchangeName) + } + *dropped++ + } + return false + } +} + +// monitorConnection monitors the connection and attempts to reconnect if the connection is lost +func (w *Websocket) monitorConnection() func() bool { + timer := time.NewTimer(w.connectionMonitorDelay) + return func() bool { return w.observeConnection(timer) } +} + +// observeConnection observes the connection and attempts to reconnect if the connection is lost +func (w *Websocket) observeConnection(t *time.Timer) (exit bool) { + select { + case err := <-w.ReadMessageErrors: + if errors.Is(err, errConnectionFault) { + log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) + if w.IsConnected() { + if shutdownErr := w.Shutdown(); shutdownErr != nil { + log.Errorf(log.WebsocketMgr, "%v websocket: connectionMonitor shutdown err: %s", w.exchangeName, shutdownErr) + } + } + } + // Speedier reconnection, instead of waiting for the next cycle. + if w.IsEnabled() && (!w.IsConnected() && !w.IsConnecting()) { + if connectErr := w.Connect(); connectErr != nil { + log.Errorln(log.WebsocketMgr, connectErr) + } + } + w.DataHandler <- err // hand over the error to the data handler (shutdown and reconnection is priority) + case <-t.C: + if w.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", w.exchangeName) + } + if !w.IsEnabled() { + if w.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", w.exchangeName) + } + if w.IsConnected() { + if err := w.Shutdown(); err != nil { + log.Errorln(log.WebsocketMgr, err) + } + } + if w.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", w.exchangeName) + } + t.Stop() + w.connectionMonitorRunning.Store(false) + return true + } + if !w.IsConnecting() && !w.IsConnected() { + err := w.Connect() + if err != nil { + log.Errorln(log.WebsocketMgr, err) + } + } + t.Reset(w.connectionMonitorDelay) + } + return false +} + +// monitorTraffic monitors to see if there has been traffic within the trafficTimeout time window. If there is no traffic +// the connection is shutdown and will be reconnected by the connectionMonitor routine. +func (w *Websocket) monitorTraffic() func() bool { + timer := time.NewTimer(w.trafficTimeout) + return func() bool { return w.observeTraffic(timer) } +} + +func (w *Websocket) observeTraffic(t *time.Timer) bool { + select { + case <-w.ShutdownC: + if w.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName) + } + case <-t.C: + if w.IsConnecting() || signalReceived(w.TrafficAlert) { + t.Reset(w.trafficTimeout) + return false + } + if w.verbose { + log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout) + } + if w.IsConnected() { + w.Wg.Done() // Without this the w.Shutdown() call below will deadlock + if err := w.Shutdown(); err != nil { + log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) + } + w.Wg.Add(1) + } + } + t.Stop() + return true +} + +// signalReceived checks if a signal has been received, this also clears the signal. +func signalReceived(ch chan struct{}) bool { + select { + case <-ch: + return true + default: + return false + } +} diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 2280716865b..35463ece0c1 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -165,102 +165,6 @@ func TestSetup(t *testing.T) { assert.NoError(t, err, "Setup should not error") } -// TestTrafficMonitorTrafficAlerts ensures multiple traffic alerts work and only process one trafficAlert per interval -// ensures shutdown works after traffic alerts -func TestTrafficMonitorTrafficAlerts(t *testing.T) { - t.Parallel() - ws := NewWebsocket() - err := ws.Setup(defaultSetup) - require.NoError(t, err, "Setup must not error") - - ws.trafficTimeout = 200 * time.Millisecond - ws.state.Store(connectedState) - - thenish := time.Now() - ws.trafficMonitor() - - assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - require.Equal(t, connectedState, ws.state.Load(), "websocket must be connected") - - for i := range 6 { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass - select { - case ws.TrafficAlert <- struct{}{}: - require.WithinDuration(t, time.Now(), thenish.Add(time.Duration(i)*trafficCheckInterval), trafficCheckInterval*2) - default: - require.Failf(t, "", "TrafficAlert should not block; Check #%d", i) - } - - select { - case ws.TrafficAlert <- struct{}{}: - require.Failf(t, "", "TrafficAlert should block after first slot used; Check #%d", i) - default: - require.WithinDuration(t, time.Now(), thenish.Add(time.Duration(i)*trafficCheckInterval), trafficCheckInterval*2) - } - - require.Eventuallyf(t, func() bool { return len(ws.TrafficAlert) == 0 }, 2*trafficCheckInterval, time.Millisecond, "trafficAlert should be drained; Check #%d", i) - assert.Truef(t, ws.IsConnected(), "state should still be connected; Check #%d", i) - } - - require.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected") - assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") - }, 2*ws.trafficTimeout, time.Millisecond, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts") -} - -// TestTrafficMonitorConnecting ensures connecting status doesn't trigger shutdown -func TestTrafficMonitorConnecting(t *testing.T) { - t.Parallel() - ws := NewWebsocket() - err := ws.Setup(defaultSetup) - require.NoError(t, err, "Setup must not error") - - ws.state.Store(connectingState) - ws.trafficTimeout = 50 * time.Millisecond - ws.trafficMonitor() - require.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - require.Equal(t, connectingState, ws.state.Load(), "websocket must be connecting") - <-time.After(4 * ws.trafficTimeout) - require.Equal(t, connectingState, ws.state.Load(), "websocket must still be connecting after several checks") - ws.state.Store(connectedState) - require.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected") - assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") - }, 8*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") -} - -// TestTrafficMonitorShutdown ensures shutdown is processed and waitgroup is cleared -func TestTrafficMonitorShutdown(t *testing.T) { - t.Parallel() - ws := NewWebsocket() - err := ws.Setup(defaultSetup) - require.NoError(t, err, "Setup must not error") - - ws.state.Store(connectedState) - ws.trafficTimeout = time.Minute - ws.verbose = true - close(ws.TrafficAlert) // Mocks channel traffic signal. - ws.trafficMonitor() - assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - - wgReady := make(chan bool) - go func() { ws.Wg.Wait(); close(wgReady) }() - select { - case <-wgReady: - require.Fail(t, "WaitGroup should be blocking still") - case <-time.After(trafficCheckInterval): - } - - close(ws.ShutdownC) - - <-time.After(4 * trafficCheckInterval) - assert.False(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be shutdown") - select { - case <-wgReady: - default: - require.Fail(t, "WaitGroup should be freed now") - } -} - func TestConnectionMessageErrors(t *testing.T) { t.Parallel() var wsWrong = &Websocket{} @@ -673,31 +577,6 @@ func TestSuccessfulSubscriptions(t *testing.T) { assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), common.ErrNilPointer, "Should error correctly when nil websocket") } -// TestConnectionMonitorNoConnection logic test -func TestConnectionMonitorNoConnection(t *testing.T) { - t.Parallel() - ws := NewWebsocket() - ws.connectionMonitorDelay = 500 - ws.exchangeName = "hello" - ws.setEnabled(true) - err := ws.connectionMonitor() - require.NoError(t, err, "connectionMonitor must not error") - assert.True(t, ws.IsConnectionMonitorRunning(), "IsConnectionMonitorRunning should return true") - err = ws.connectionMonitor() - assert.ErrorIs(t, err, errAlreadyRunning, "connectionMonitor should error correctly") - - ws.setState(connectedState) - ws.ReadMessageErrors <- errConnectionFault - select { - case data := <-ws.DataHandler: - err, ok := data.(error) - require.True(t, ok, "DataHandler should return an error") - require.ErrorIs(t, err, errConnectionFault, "DataHandler should return the correct error") - case <-time.After(2 * time.Second): - t.Fatal("DataHandler should return an error") - } -} - // TestGetSubscription logic test func TestGetSubscription(t *testing.T) { t.Parallel() @@ -1495,3 +1374,104 @@ func TestDrain(t *testing.T) { drain(ch) require.Empty(t, ch, "Drain should empty the channel") } + +func TestMonitorFrame(t *testing.T) { + t.Parallel() + ws := Websocket{} + require.Panics(t, func() { ws.monitorFrame(nil, nil) }, "monitorFrame must panic on nil frame") + require.Panics(t, func() { ws.monitorFrame(nil, func() func() bool { return nil }) }, "monitorFrame must panic on nil function") + ws.Wg.Add(1) + ws.monitorFrame(&ws.Wg, func() func() bool { return func() bool { return true } }) + ws.Wg.Wait() +} + +func TestMonitorData(t *testing.T) { + t.Parallel() + ws := Websocket{ShutdownC: make(chan struct{}), DataHandler: make(chan interface{}, 10)} + // Handle shutdown signal + close(ws.ShutdownC) + require.True(t, ws.observeData(nil)) + ws.ShutdownC = make(chan struct{}) + // Handle blockage of ToRoutine + go func() { ws.DataHandler <- nil }() + var dropped int + require.False(t, ws.observeData(&dropped)) + require.Equal(t, 1, dropped) + // Handle reinstate of ToRoutine functionality which will reset dropped counter + ws.ToRoutine = make(chan interface{}, 10) + go func() { ws.DataHandler <- nil }() + require.False(t, ws.observeData(&dropped)) + require.Empty(t, dropped) + // Handle outter closure shell + innerShell := ws.monitorData() + go func() { ws.DataHandler <- nil }() + require.False(t, innerShell()) + // Handle shutdown signal + close(ws.ShutdownC) + require.True(t, innerShell()) +} + +func TestMonitorConnection(t *testing.T) { + t.Parallel() + ws := Websocket{verbose: true, ReadMessageErrors: make(chan error, 1), ShutdownC: make(chan struct{})} + // Handle timer expired and websocket disabled, shutdown everything. + timer := time.NewTimer(0) + ws.setState(connectedState) + ws.connectionMonitorRunning.Store(true) + require.True(t, ws.observeConnection(timer)) + require.False(t, ws.connectionMonitorRunning.Load()) + require.Equal(t, disconnectedState, ws.state.Load()) + // Handle timer expired and everything is great, reset the timer. + ws.setEnabled(true) + ws.setState(connectedState) + ws.connectionMonitorRunning.Store(true) + timer = time.NewTimer(0) + require.False(t, ws.observeConnection(timer)) // Not shutting down + // Handle timer expired and for reason its not connected, so lets happily connect again. + ws.setState(disconnectedState) + require.False(t, ws.observeConnection(timer)) // Connect is intentionally erroring + // Handle error from a connection which will then trigger a reconnect + ws.setState(connectedState) + ws.DataHandler = make(chan interface{}, 1) + ws.ReadMessageErrors <- errConnectionFault + timer = time.NewTimer(time.Second) + require.False(t, ws.observeConnection(timer)) + payload := <-ws.DataHandler + err, ok := payload.(error) + require.True(t, ok) + require.ErrorIs(t, err, errConnectionFault) + // Handle outta closure shell + innerShell := ws.monitorConnection() + ws.setState(connectedState) + ws.ReadMessageErrors <- errConnectionFault + require.False(t, innerShell()) +} + +func TestMonitorTraffic(t *testing.T) { + t.Parallel() + ws := Websocket{verbose: true, ShutdownC: make(chan struct{}), TrafficAlert: make(chan struct{}, 1)} + ws.Wg.Add(1) + // Handle external shutdown signal + timer := time.NewTimer(time.Second) + close(ws.ShutdownC) + require.True(t, ws.observeTraffic(timer)) + // Handle timer expired but system is connecting, so reset the timer + ws.ShutdownC = make(chan struct{}) + ws.setState(connectingState) + timer = time.NewTimer(0) + require.False(t, ws.observeTraffic(timer)) + // Handle timer expired and system is connected and has traffic within time window + ws.setState(connectedState) + timer = time.NewTimer(0) + ws.TrafficAlert <- struct{}{} + require.False(t, ws.observeTraffic(timer)) + // Handle timer expired and system is connected but no traffic within time window, causes shutdown to occur. + timer = time.NewTimer(0) + require.True(t, ws.observeTraffic(timer)) + require.Equal(t, disconnectedState, ws.state.Load()) + // Handle outter closure shell + innerShell := ws.monitorTraffic() + ws.setState(connectedState) + ws.TrafficAlert <- struct{}{} + require.False(t, innerShell()) +} diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index e903e997157..27a5c81963f 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -38,8 +38,6 @@ type Websocket struct { state atomic.Uint32 verbose bool connectionMonitorRunning atomic.Bool - trafficMonitorRunning atomic.Bool - dataMonitorRunning atomic.Bool trafficTimeout time.Duration connectionMonitorDelay time.Duration proxyAddr string From 10230b07ae8342118c3e68f11561442347505791 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 25 Sep 2024 12:17:52 +1000 Subject: [PATCH 083/138] spells --- exchanges/stream/websocket.go | 2 +- exchanges/stream/websocket_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 388ad94cefb..8ddb997b025 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -1101,7 +1101,7 @@ func drain(ch <-chan error) { // ClosureFrame is a closure function that wraps monitoring variables with observer, if the return is true the frame will exit type ClosureFrame func() func() bool -// monitorFrame monitors a specific websocket componant or critical system. It will exit if the observer returns true +// monitorFrame monitors a specific websocket component or critical system. It will exit if the observer returns true // This is used for monitoring data throughput, connection status and other critical websocket components. The waitgroup // is optional and is used to signal when the monitor has finished. func (w *Websocket) monitorFrame(wg *sync.WaitGroup, fn ClosureFrame) { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 35463ece0c1..01c5a2a7f10 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1402,7 +1402,7 @@ func TestMonitorData(t *testing.T) { go func() { ws.DataHandler <- nil }() require.False(t, ws.observeData(&dropped)) require.Empty(t, dropped) - // Handle outter closure shell + // Handle outer closure shell innerShell := ws.monitorData() go func() { ws.DataHandler <- nil }() require.False(t, innerShell()) @@ -1469,7 +1469,7 @@ func TestMonitorTraffic(t *testing.T) { timer = time.NewTimer(0) require.True(t, ws.observeTraffic(timer)) require.Equal(t, disconnectedState, ws.state.Load()) - // Handle outter closure shell + // Handle outer closure shell innerShell := ws.monitorTraffic() ws.setState(connectedState) ws.TrafficAlert <- struct{}{} From 4e02c4afe80da7270d753d318104bbb1ce947f39 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 25 Sep 2024 12:38:13 +1000 Subject: [PATCH 084/138] cant do what I did --- exchanges/stream/websocket.go | 10 +++++----- exchanges/stream/websocket_test.go | 4 +++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 8ddb997b025..5451bb882f0 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -1223,11 +1223,11 @@ func (w *Websocket) observeTraffic(t *time.Timer) bool { log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout) } if w.IsConnected() { - w.Wg.Done() // Without this the w.Shutdown() call below will deadlock - if err := w.Shutdown(); err != nil { - log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) - } - w.Wg.Add(1) + go func() { // Without this the w.Shutdown() call below will deadlock + if err := w.Shutdown(); err != nil { + log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) + } + }() } } t.Stop() diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 01c5a2a7f10..8f162ed7681 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1468,9 +1468,11 @@ func TestMonitorTraffic(t *testing.T) { // Handle timer expired and system is connected but no traffic within time window, causes shutdown to occur. timer = time.NewTimer(0) require.True(t, ws.observeTraffic(timer)) - require.Equal(t, disconnectedState, ws.state.Load()) + // Shutdown is done in a routine, so we need to wait for it to finish + require.Eventually(t, func() bool { return disconnectedState == ws.state.Load() }, time.Second, time.Millisecond) // Handle outer closure shell innerShell := ws.monitorTraffic() + ws.ShutdownC = make(chan struct{}) ws.setState(connectedState) ws.TrafficAlert <- struct{}{} require.False(t, innerShell()) From 7be80a0ed32891e353c522794aed4a7e088c7fb7 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 25 Sep 2024 13:01:45 +1000 Subject: [PATCH 085/138] protect race due to routine. --- exchanges/stream/websocket_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 8f162ed7681..fcc0bb66c8e 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1468,11 +1468,14 @@ func TestMonitorTraffic(t *testing.T) { // Handle timer expired and system is connected but no traffic within time window, causes shutdown to occur. timer = time.NewTimer(0) require.True(t, ws.observeTraffic(timer)) + ws.Wg.Done() // Shutdown is done in a routine, so we need to wait for it to finish require.Eventually(t, func() bool { return disconnectedState == ws.state.Load() }, time.Second, time.Millisecond) // Handle outer closure shell innerShell := ws.monitorTraffic() + ws.m.Lock() ws.ShutdownC = make(chan struct{}) + ws.m.Unlock() ws.setState(connectedState) ws.TrafficAlert <- struct{}{} require.False(t, innerShell()) From b51cf2f70c653efcb8db5580c09abff2825f4d13 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 25 Sep 2024 13:13:31 +1000 Subject: [PATCH 086/138] update testURL --- exchanges/stream/websocket_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index fcc0bb66c8e..90b80b84cbe 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -28,7 +28,7 @@ import ( ) const ( - websocketTestURL = "wss://www.bitmex.com/realtime" + websocketTestURL = "wss://ws.bitmex.com/realtime" useProxyTests = false // Disabled by default. Freely available proxy servers that work all the time are difficult to find proxyURL = "http://212.186.171.4:80" // Replace with a usable proxy server ) From ff7ae03e109662ed5dac0e98e12dad34ce640ed9 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 25 Sep 2024 13:41:15 +1000 Subject: [PATCH 087/138] use mock websocket connection instead of test URL's --- exchanges/stream/websocket_test.go | 121 ++++++++++++++--------------- 1 file changed, 57 insertions(+), 64 deletions(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 90b80b84cbe..fb83c1b2986 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -37,8 +37,6 @@ var ( errDastardlyReason = errors.New("some dastardly reason") ) -var dialer websocket.Dialer - type testStruct struct { Error error WC WebsocketConnection @@ -616,17 +614,22 @@ func TestSetCanUseAuthenticatedEndpoints(t *testing.T) { // TestDial logic test func TestDial(t *testing.T) { t.Parallel() + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + var testCases = []testStruct{ - {Error: nil, + { WC: WebsocketConnection{ ExchangeName: "test1", Verbose: true, - URL: websocketTestURL, + URL: "ws" + mock.URL[len("http"):] + "/ws", RateLimit: request.NewWeightedRateLimitByDuration(10 * time.Millisecond), ResponseMaxLimit: 7000000000, }, }, - {Error: errors.New(" Error: malformed ws or wss URL"), + { + Error: errors.New(" Error: malformed ws or wss URL"), WC: WebsocketConnection{ ExchangeName: "test2", Verbose: true, @@ -634,26 +637,24 @@ func TestDial(t *testing.T) { ResponseMaxLimit: 7000000000, }, }, - {Error: nil, + { WC: WebsocketConnection{ ExchangeName: "test3", Verbose: true, - URL: websocketTestURL, + URL: "ws" + mock.URL[len("http"):] + "/ws", ProxyURL: proxyURL, ResponseMaxLimit: 7000000000, }, }, } for i := range testCases { - testData := &testCases[i] - t.Run(testData.WC.ExchangeName, func(t *testing.T) { - t.Parallel() - if testData.WC.ProxyURL != "" && !useProxyTests { + t.Run(testCases[i].WC.ExchangeName, func(t *testing.T) { + if testCases[i].WC.ProxyURL != "" && !useProxyTests { t.Skip("Proxy testing not enabled, skipping") } - err := testData.WC.Dial(&dialer, http.Header{}) + err := testCases[i].WC.Dial(&websocket.Dialer{}, http.Header{}) if err != nil { - if testData.Error != nil && strings.Contains(err.Error(), testData.Error.Error()) { + if testCases[i].Error != nil && strings.Contains(err.Error(), testCases[i].Error.Error()) { return } t.Fatal(err) @@ -665,16 +666,22 @@ func TestDial(t *testing.T) { // TestSendMessage logic test func TestSendMessage(t *testing.T) { t.Parallel() + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + var testCases = []testStruct{ - {Error: nil, WC: WebsocketConnection{ - ExchangeName: "test1", - Verbose: true, - URL: websocketTestURL, - RateLimit: request.NewWeightedRateLimitByDuration(10 * time.Millisecond), - ResponseMaxLimit: 7000000000, - }, + { + WC: WebsocketConnection{ + ExchangeName: "test1", + Verbose: true, + URL: "ws" + mock.URL[len("http"):] + "/ws", + RateLimit: request.NewWeightedRateLimitByDuration(10 * time.Millisecond), + ResponseMaxLimit: 7000000000, + }, }, - {Error: errors.New(" Error: malformed ws or wss URL"), + { + Error: errors.New(" Error: malformed ws or wss URL"), WC: WebsocketConnection{ ExchangeName: "test2", Verbose: true, @@ -682,38 +689,32 @@ func TestSendMessage(t *testing.T) { ResponseMaxLimit: 7000000000, }, }, - {Error: nil, + { WC: WebsocketConnection{ ExchangeName: "test3", Verbose: true, - URL: websocketTestURL, + URL: "ws" + mock.URL[len("http"):] + "/ws", ProxyURL: proxyURL, ResponseMaxLimit: 7000000000, }, }, } - for i := range testCases { - testData := &testCases[i] - t.Run(testData.WC.ExchangeName, func(t *testing.T) { - t.Parallel() - if testData.WC.ProxyURL != "" && !useProxyTests { + for x := range testCases { + t.Run(testCases[x].WC.ExchangeName, func(t *testing.T) { + if testCases[x].WC.ProxyURL != "" && !useProxyTests { t.Skip("Proxy testing not enabled, skipping") } - err := testData.WC.Dial(&dialer, http.Header{}) + err := testCases[x].WC.Dial(&websocket.Dialer{}, http.Header{}) if err != nil { - if testData.Error != nil && strings.Contains(err.Error(), testData.Error.Error()) { + if testCases[x].Error != nil && strings.Contains(err.Error(), testCases[x].Error.Error()) { return } t.Fatal(err) } - err = testData.WC.SendJSONMessage(context.Background(), request.Unset, Ping) - if err != nil { - t.Error(err) - } - err = testData.WC.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte(Ping)) - if err != nil { - t.Error(err) - } + err = testCases[x].WC.SendJSONMessage(context.Background(), request.Unset, Ping) + require.NoError(t, err) + err = testCases[x].WC.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte(Ping)) + require.NoError(t, err) }) } } @@ -730,7 +731,7 @@ func TestSendMessageReturnResponse(t *testing.T) { t.Skip("Proxy testing not enabled, skipping") } - err := wc.Dial(&dialer, http.Header{}) + err := wc.Dial(&websocket.Dialer{}, http.Header{}) if err != nil { t.Fatal(err) } @@ -816,7 +817,7 @@ func TestSetupPingHandler(t *testing.T) { t.Skip("Proxy testing not enabled, skipping") } wc.shutdown = make(chan struct{}) - err := wc.Dial(&dialer, http.Header{}) + err := wc.Dial(&websocket.Dialer{}, http.Header{}) if err != nil { t.Fatal(err) } @@ -832,7 +833,7 @@ func TestSetupPingHandler(t *testing.T) { t.Error(err) } - err = wc.Dial(&dialer, http.Header{}) + err = wc.Dial(&websocket.Dialer{}, http.Header{}) if err != nil { t.Fatal(err) } @@ -1253,13 +1254,17 @@ func TestWebsocketConnectionShutdown(t *testing.T) { // TestLatency logic test func TestLatency(t *testing.T) { t.Parallel() + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + r := &reporter{} exch := "Kraken" wc := &WebsocketConnection{ ExchangeName: exch, Verbose: true, - URL: "wss://ws.kraken.com", - ResponseMaxLimit: time.Second * 5, + URL: "ws" + mock.URL[len("http"):] + "/ws", + ResponseMaxLimit: time.Second * 1, Match: NewMatch(), Reporter: r, } @@ -1267,34 +1272,22 @@ func TestLatency(t *testing.T) { t.Skip("Proxy testing not enabled, skipping") } - err := wc.Dial(&dialer, http.Header{}) - if err != nil { - t.Fatal(err) - } + err := wc.Dial(&websocket.Dialer{}, http.Header{}) + require.NoError(t, err) go readMessages(t, wc) req := testRequest{ - Event: "subscribe", - Pairs: []string{currency.NewPairWithDelimiter("XBT", "USD", "/").String()}, - Subscription: testRequestData{ - Name: "ticker", - }, - RequestID: wc.GenerateMessageID(false), + Event: "subscribe", + Pairs: []string{currency.NewPairWithDelimiter("XBT", "USD", "/").String()}, + Subscription: testRequestData{Name: "ticker"}, + RequestID: wc.GenerateMessageID(false), } _, err = wc.SendMessageReturnResponse(context.Background(), request.Unset, req.RequestID, req) - if err != nil { - t.Error(err) - } - - if r.t == 0 { - t.Error("expected a nonzero duration, got zero") - } - - if r.name != exch { - t.Errorf("expected %v, got %v", exch, r.name) - } + require.NoError(t, err) + require.NotEmpty(t, r.t, "Latency should have a duration") + require.Equal(t, exch, r.name, "Latency should have the correct exchange name") } func TestCheckSubscriptions(t *testing.T) { From 1040153eae69a514c4842ec36dbd134fa3ef72a1 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 25 Sep 2024 13:50:12 +1000 Subject: [PATCH 088/138] linter: fix --- exchanges/stream/websocket_test.go | 52 +++++++++++++++--------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index fb83c1b2986..47efd5ccf2d 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -647,19 +647,19 @@ func TestDial(t *testing.T) { }, }, } + // Mock server rejects parallel connections for i := range testCases { - t.Run(testCases[i].WC.ExchangeName, func(t *testing.T) { - if testCases[i].WC.ProxyURL != "" && !useProxyTests { - t.Skip("Proxy testing not enabled, skipping") - } - err := testCases[i].WC.Dial(&websocket.Dialer{}, http.Header{}) - if err != nil { - if testCases[i].Error != nil && strings.Contains(err.Error(), testCases[i].Error.Error()) { - return - } - t.Fatal(err) + if testCases[i].WC.ProxyURL != "" && !useProxyTests { + t.Log("Proxy testing not enabled, skipping") + continue + } + err := testCases[i].WC.Dial(&websocket.Dialer{}, http.Header{}) + if err != nil { + if testCases[i].Error != nil && strings.Contains(err.Error(), testCases[i].Error.Error()) { + return } - }) + t.Fatal(err) + } } } @@ -699,23 +699,23 @@ func TestSendMessage(t *testing.T) { }, }, } + // Mock server rejects parallel connections for x := range testCases { - t.Run(testCases[x].WC.ExchangeName, func(t *testing.T) { - if testCases[x].WC.ProxyURL != "" && !useProxyTests { - t.Skip("Proxy testing not enabled, skipping") - } - err := testCases[x].WC.Dial(&websocket.Dialer{}, http.Header{}) - if err != nil { - if testCases[x].Error != nil && strings.Contains(err.Error(), testCases[x].Error.Error()) { - return - } - t.Fatal(err) + if testCases[x].WC.ProxyURL != "" && !useProxyTests { + t.Log("Proxy testing not enabled, skipping") + continue + } + err := testCases[x].WC.Dial(&websocket.Dialer{}, http.Header{}) + if err != nil { + if testCases[x].Error != nil && strings.Contains(err.Error(), testCases[x].Error.Error()) { + return } - err = testCases[x].WC.SendJSONMessage(context.Background(), request.Unset, Ping) - require.NoError(t, err) - err = testCases[x].WC.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte(Ping)) - require.NoError(t, err) - }) + t.Fatal(err) + } + err = testCases[x].WC.SendJSONMessage(context.Background(), request.Unset, Ping) + require.NoError(t, err) + err = testCases[x].WC.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte(Ping)) + require.NoError(t, err) } } From f529bd2e9280b5510304ce503c64a9a28767ea8e Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 25 Sep 2024 14:06:54 +1000 Subject: [PATCH 089/138] remove url because its throwing errors on CI builds --- exchanges/stream/websocket_test.go | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 47efd5ccf2d..30ac82df2e7 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -28,9 +28,8 @@ import ( ) const ( - websocketTestURL = "wss://ws.bitmex.com/realtime" - useProxyTests = false // Disabled by default. Freely available proxy servers that work all the time are difficult to find - proxyURL = "http://212.186.171.4:80" // Replace with a usable proxy server + useProxyTests = false // Disabled by default. Freely available proxy servers that work all the time are difficult to find + proxyURL = "http://212.186.171.4:80" // Replace with a usable proxy server ) var ( @@ -806,8 +805,12 @@ func readMessages(t *testing.T, wc *WebsocketConnection) { // TestSetupPingHandler logic test func TestSetupPingHandler(t *testing.T) { t.Parallel() + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + wc := &WebsocketConnection{ - URL: websocketTestURL, + URL: "ws" + mock.URL[len("http"):] + "/ws", ResponseMaxLimit: time.Second * 5, Match: NewMatch(), Wg: &sync.WaitGroup{}, @@ -850,8 +853,12 @@ func TestSetupPingHandler(t *testing.T) { // TestParseBinaryResponse logic test func TestParseBinaryResponse(t *testing.T) { t.Parallel() + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + wc := &WebsocketConnection{ - URL: websocketTestURL, + URL: "ws" + mock.URL[len("http"):] + "/ws", ResponseMaxLimit: time.Second * 5, Match: NewMatch(), } @@ -1242,7 +1249,10 @@ func TestWebsocketConnectionShutdown(t *testing.T) { err = wc.Dial(&websocket.Dialer{}, nil) assert.ErrorContains(t, err, "malformed ws or wss URL", "Dial must error correctly") - wc.URL = websocketTestURL + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + + wc.URL = "ws" + mock.URL[len("http"):] + "/ws" err = wc.Dial(&websocket.Dialer{}, nil) require.NoError(t, err, "Dial must not error") From 5d0a7f7b7f8130fc5a29dda5d85ea690965a0885 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 25 Sep 2024 14:38:56 +1000 Subject: [PATCH 090/138] connections drop all the time, don't need to worry about not being able to echo back ws data as it can be easily reviewed _test file side. --- internal/testing/websocket/mock.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/internal/testing/websocket/mock.go b/internal/testing/websocket/mock.go index fb93d9becbb..0e553b0c2ad 100644 --- a/internal/testing/websocket/mock.go +++ b/internal/testing/websocket/mock.go @@ -40,11 +40,8 @@ func WsMockUpgrader(tb testing.TB, w http.ResponseWriter, r *http.Request, wsHan } } -// EchoHandler is a simple echo function after a read +// EchoHandler is a simple echo function after a read, this doesn't need to worry if writing to the connection fails func EchoHandler(p []byte, c *websocket.Conn) error { - err := c.WriteMessage(websocket.TextMessage, p) - if err != nil { - return err - } + _ = c.WriteMessage(websocket.TextMessage, p) return nil } From 56cb431e6ac2f06f8bd7bc5b0dad8b5aed6dce47 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 25 Sep 2024 14:49:00 +1000 Subject: [PATCH 091/138] remove another superfluous url thats not really set up for this --- exchanges/stream/websocket_test.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 30ac82df2e7..6a86a5319c1 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -720,9 +720,13 @@ func TestSendMessage(t *testing.T) { func TestSendMessageReturnResponse(t *testing.T) { t.Parallel() + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + wc := &WebsocketConnection{ Verbose: true, - URL: "wss://ws.kraken.com", + URL: "ws" + mock.URL[len("http"):] + "/ws", ResponseMaxLimit: time.Second * 5, Match: NewMatch(), } From ca4999eb1c805cd69830447e85e1985280130efa Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 26 Sep 2024 09:59:46 +1000 Subject: [PATCH 092/138] spawn overwatch routine when there is no errors, inline checker instead of waiting for a time period, add sleep inline with echo handler as this is really quick and wanted to ensure that latency is handing correctly --- exchanges/stream/websocket.go | 16 +++++++++++----- exchanges/stream/websocket_test.go | 26 +++++++++++--------------- internal/testing/websocket/mock.go | 2 ++ 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 5451bb882f0..001bcc28391 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -341,11 +341,6 @@ func (w *Websocket) connect() error { go w.monitorFrame(&w.Wg, w.monitorData) go w.monitorFrame(&w.Wg, w.monitorTraffic) - if w.connectionMonitorRunning.CompareAndSwap(false, true) { - // This oversees all connections and does not need to be part of wait group management. - go w.monitorFrame(nil, w.monitorConnection) - } - if !w.useMultiConnectionManagement { if w.connector == nil { return fmt.Errorf("%v %w", w.exchangeName, errNoConnectFunc) @@ -357,6 +352,11 @@ func (w *Websocket) connect() error { } w.setState(connectedState) + if w.connectionMonitorRunning.CompareAndSwap(false, true) { + // This oversees all connections and does not need to be part of wait group management. + go w.monitorFrame(nil, w.monitorConnection) + } + subs, err := w.GenerateSubs() // regenerate state on new connection if err != nil { return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) @@ -474,6 +474,12 @@ func (w *Websocket) connect() error { // All subscriptions have been sent and stored. All data received is being // handled by the appropriate data handler. w.setState(connectedState) + + if w.connectionMonitorRunning.CompareAndSwap(false, true) { + // This oversees all connections and does not need to be part of wait group management. + go w.monitorFrame(nil, w.monitorConnection) + } + return nil } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 6a86a5319c1..38c4ced9f9c 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -194,29 +194,25 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Connect() require.NoError(t, err, "Connect must not error") - c := func(tb *assert.CollectT) { - select { - case v, ok := <-ws.ToRoutine: - require.True(tb, ok, "ToRoutine should not be closed on us") - switch err := v.(type) { - case *websocket.CloseError: - assert.Equal(tb, "SpecialText", err.Text, "Should get correct Close Error") - case error: - assert.ErrorIs(tb, err, errDastardlyReason, "Should get the correct error") - default: - assert.Failf(tb, "Wrong data type sent to ToRoutine", "Got type: %T", err) - } + checkToRoutineResult := func(t *testing.T) { + v, ok := <-ws.ToRoutine + require.True(t, ok, "ToRoutine should not be closed on us") + switch err := v.(type) { + case *websocket.CloseError: + assert.Equal(t, "SpecialText", err.Text, "Should get correct Close Error") + case error: + assert.ErrorIs(t, err, errDastardlyReason, "Should get the correct error") default: - assert.Fail(tb, "Nothing available on ToRoutine") + assert.Failf(t, "Wrong data type sent to ToRoutine", "Got type: %T", err) } } ws.TrafficAlert <- struct{}{} ws.ReadMessageErrors <- errDastardlyReason - assert.EventuallyWithT(t, c, 2*time.Second, 10*time.Millisecond, "Should get an error down the routine") + checkToRoutineResult(t) ws.ReadMessageErrors <- &websocket.CloseError{Code: 1006, Text: "SpecialText"} - assert.EventuallyWithT(t, c, 2*time.Second, 10*time.Millisecond, "Should get an error down the routine") + checkToRoutineResult(t) // Test individual connection defined functions require.NoError(t, ws.Shutdown()) diff --git a/internal/testing/websocket/mock.go b/internal/testing/websocket/mock.go index 0e553b0c2ad..ebaacf0a449 100644 --- a/internal/testing/websocket/mock.go +++ b/internal/testing/websocket/mock.go @@ -3,6 +3,7 @@ package websocket import ( "net/http" "testing" + "time" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" @@ -42,6 +43,7 @@ func WsMockUpgrader(tb testing.TB, w http.ResponseWriter, r *http.Request, wsHan // EchoHandler is a simple echo function after a read, this doesn't need to worry if writing to the connection fails func EchoHandler(p []byte, c *websocket.Conn) error { + time.Sleep(time.Nanosecond) // Shift clock to simulate time passing _ = c.WriteMessage(websocket.TextMessage, p) return nil } From 4240a0a1189b4e76c465031f8a95fc2ca7c6f240 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 26 Sep 2024 10:04:59 +1000 Subject: [PATCH 093/138] linter: fixerino uperino --- exchanges/stream/websocket_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 38c4ced9f9c..2904bccadee 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -195,6 +195,7 @@ func TestConnectionMessageErrors(t *testing.T) { require.NoError(t, err, "Connect must not error") checkToRoutineResult := func(t *testing.T) { + t.Helper() v, ok := <-ws.ToRoutine require.True(t, ok, "ToRoutine should not be closed on us") switch err := v.(type) { From 8d6febcc122e905f6512b4763a21d4e299db94d7 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Mon, 30 Sep 2024 12:54:10 +1000 Subject: [PATCH 094/138] glorious: panix --- exchanges/stream/websocket.go | 3 +-- exchanges/stream/websocket_connection.go | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 001bcc28391..33f17a37222 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -191,7 +191,6 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { } w.trafficTimeout = s.ExchangeConfig.WebsocketTrafficTimeout - w.ShutdownC = make(chan struct{}) w.SetCanUseAuthenticatedEndpoints(s.ExchangeConfig.API.AuthenticatedWebsocketSupport) if err := w.Orderbook.Setup(s.ExchangeConfig, &s.OrderbookBufferConfig, w.DataHandler); err != nil { @@ -301,7 +300,7 @@ func (w *Websocket) getConnectionFromSetup(c *ConnectionSetup) *WebsocketConnect ResponseMaxLimit: c.ResponseMaxLimit, Traffic: w.TrafficAlert, readMessageErrors: w.ReadMessageErrors, - shutdown: make(chan struct{}), // Call shutdown to close the connection + shutdown: w.ShutdownC, Wg: &w.Wg, Match: w.Match, RateLimit: c.RateLimit, diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index cee93985dc8..1f6f5e6019a 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -272,7 +272,6 @@ func (w *WebsocketConnection) Shutdown() error { return nil } w.setConnectedStatus(false) - close(w.shutdown) w.writeControl.Lock() defer w.writeControl.Unlock() return w.Connection.NetConn().Close() From 3a246402b6c3dd2f39b5a217bdbf30ce99d2c953 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 1 Oct 2024 14:14:33 +1000 Subject: [PATCH 095/138] linter: things --- exchanges/gateio/gateio_test.go | 3 +-- exchanges/gateio/gateio_websocket.go | 8 ++++---- exchanges/gateio/gateio_wrapper.go | 4 ++-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index a78cabc497a..5e8ee348d4d 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -3009,9 +3009,8 @@ func TestSubscribe(t *testing.T) { subs, err := g.Features.Subscriptions.ExpandTemplates(g) require.NoError(t, err, "ExpandTemplates must not error") g.Features.Subscriptions = subscription.List{} - err = g.SpotSubscribe(context.Background(), &DummyConnection{}, subs) + err = g.Subscribe(context.Background(), &DummyConnection{}, subs) require.NoError(t, err, "Subscribe must not error") - } func TestGenerateDeliveryFuturesDefaultSubscriptions(t *testing.T) { diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 5e96d0e6ab0..8dd36d8e3e7 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -706,12 +706,12 @@ func (g *Gateio) manageSubReq(ctx context.Context, event string, conn stream.Con } // Subscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) SpotSubscribe(ctx context.Context, conn stream.Connection, subs subscription.List) error { +func (g *Gateio) Subscribe(ctx context.Context, conn stream.Connection, subs subscription.List) error { return g.manageSubs(ctx, subscribeEvent, conn, subs) } // Unsubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) SpotUnsubscribe(ctx context.Context, conn stream.Connection, subs subscription.List) error { +func (g *Gateio) Unsubscribe(ctx context.Context, conn stream.Connection, subs subscription.List) error { return g.manageSubs(ctx, unsubscribeEvent, conn, subs) } @@ -795,9 +795,9 @@ func (g *Gateio) handleSubscription(ctx context.Context, conn stream.Connection, continue } if event == subscribeEvent { - err = g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[k]) + err = common.AppendError(err, g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[k])) } else { - err = g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[k]) + err = common.AppendError(err, g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[k])) } } } diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 4c32ccfd6af..6c3377067d6 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -214,8 +214,8 @@ func (g *Gateio) Setup(exch *config.Exchange) error { ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: g.WsHandleSpotData, - Subscriber: g.SpotSubscribe, - Unsubscriber: g.SpotUnsubscribe, + Subscriber: g.Subscribe, + Unsubscriber: g.Unsubscribe, GenerateSubscriptions: g.generateSubscriptionsSpot, Connector: g.WsConnectSpot, BespokeGenerateMessageID: g.GenerateWebsocketMessageID, From 1eef208a5ab505d6dd5a6eba74f0e004d171cfae Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 1 Oct 2024 14:15:18 +1000 Subject: [PATCH 096/138] whoops --- exchanges/gateio/gateio_websocket.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 8dd36d8e3e7..d00f03b451f 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -795,9 +795,9 @@ func (g *Gateio) handleSubscription(ctx context.Context, conn stream.Connection, continue } if event == subscribeEvent { - err = common.AppendError(err, g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[k])) + errs = common.AppendError(errs, g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[k])) } else { - err = common.AppendError(err, g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[k])) + errs = common.AppendError(errs, g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[k])) } } } From b5284073e673b22d799658cd2c49b3ad530207c4 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 2 Oct 2024 10:33:24 +1000 Subject: [PATCH 097/138] dont need to make consecutive Unix() calls --- exchanges/gateio/gateio_websocket_request_spot.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index 04f2410e975..87bbc2bab27 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -42,8 +42,8 @@ func (g *Gateio) WebsocketLogin(ctx context.Context, conn stream.Connection, cha return nil, err } - tn := time.Now() - msg := "api\n" + channel + "\n" + "\n" + strconv.FormatInt(tn.Unix(), 10) + tn := time.Now().Unix() + msg := "api\n" + channel + "\n" + "\n" + strconv.FormatInt(tn, 10) mac := hmac.New(sha512.New, []byte(creds.Secret)) if _, err = mac.Write([]byte(msg)); err != nil { return nil, err @@ -54,10 +54,10 @@ func (g *Gateio) WebsocketLogin(ctx context.Context, conn stream.Connection, cha RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), APIKey: creds.Key, Signature: signature, - Timestamp: strconv.FormatInt(tn.Unix(), 10), + Timestamp: strconv.FormatInt(tn, 10), } - req := WebsocketRequest{Time: tn.Unix(), Channel: channel, Event: "api", Payload: payload} + req := WebsocketRequest{Time: tn, Channel: channel, Event: "api", Payload: payload} resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, req.Payload.RequestID, req) if err != nil { @@ -247,9 +247,9 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS return err } - tn := time.Now() + tn := time.Now().Unix() req := &WebsocketRequest{ - Time: tn.Unix(), + Time: tn, Channel: channel, Event: "api", Payload: WebsocketPayload{ @@ -257,7 +257,7 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS // response. RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), RequestParam: paramPayload, - Timestamp: strconv.FormatInt(tn.Unix(), 10), + Timestamp: strconv.FormatInt(tn, 10), }, } From ee1cb94c02df2131fad5554b098d67e089a86f6c Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 3 Oct 2024 16:09:43 +1000 Subject: [PATCH 098/138] websocket: fix potential panic on error and no responses and adding waitForResponses --- exchanges/stream/websocket_connection.go | 56 +++++++++++++----------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index dbd5ec5bf73..076d7962c91 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -315,36 +315,16 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep } start := time.Now() - err = w.SendRawMessage(ctx, epl, websocket.TextMessage, outbound) - if err != nil { + if err := w.SendRawMessage(ctx, epl, websocket.TextMessage, outbound); err != nil { return nil, err } - timeout := time.NewTimer(w.ResponseMaxLimit * time.Duration(expected)) - - resps := make([][]byte, 0, expected) - for err == nil && len(resps) < expected { - select { - case resp := <-ch: - resps = append(resps, resp) - case <-timeout.C: - w.Match.RemoveSignature(signature) - err = fmt.Errorf("%s %w %v", w.ExchangeName, ErrSignatureTimeout, signature) - case <-ctx.Done(): - w.Match.RemoveSignature(signature) - err = ctx.Err() - } - // Checks recently received message to determine if this is in fact the - // final message in a sequence of messages. - if len(isFinalMessage) == 1 && isFinalMessage[0](resps[len(resps)-1]) { - w.Match.RemoveSignature(signature) - break - } + resps, err := w.waitForResponses(ctx, signature, ch, expected, isFinalMessage...) + if err != nil { + return nil, err } - timeout.Stop() - - if err == nil && w.Reporter != nil { + if w.Reporter != nil { w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start)) } @@ -358,6 +338,32 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep return resps, err } +// waitForResponses waits for N responses from a channel +func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature any, ch <-chan []byte, expected int, isFinalMessage ...Inspector) ([][]byte, error) { + timeout := time.NewTimer(w.ResponseMaxLimit * time.Duration(expected)) + defer timeout.Stop() + + resps := make([][]byte, 0, expected) + for range expected { + select { + case resp := <-ch: + resps = append(resps, resp) + // Checks recently received message to determine if this is in fact the final message in a sequence of messages. + if len(isFinalMessage) == 1 && isFinalMessage[0](resp) { + w.Match.RemoveSignature(signature) + return resps, nil + } + case <-timeout.C: + w.Match.RemoveSignature(signature) + return nil, fmt.Errorf("%s %w %v", w.ExchangeName, ErrSignatureTimeout, signature) + case <-ctx.Done(): + w.Match.RemoveSignature(signature) + return nil, ctx.Err() + } + } + return resps, nil +} + func removeURLQueryString(url string) string { if index := strings.Index(url, "?"); index != -1 { return url[:index] From e5accaf60541d3cd5898730a0d3f67c151129c88 Mon Sep 17 00:00:00 2001 From: shazbert Date: Fri, 11 Oct 2024 04:42:18 +1100 Subject: [PATCH 099/138] rm json parser and handle in json package instead --- exchanges/gateio/gateio_types.go | 13 ++++++------ exchanges/gateio/gateio_websocket.go | 18 +++++++--------- .../gateio/gateio_websocket_request_spot.go | 21 +++++++------------ 3 files changed, 22 insertions(+), 30 deletions(-) diff --git a/exchanges/gateio/gateio_types.go b/exchanges/gateio/gateio_types.go index 2b3e558c29a..267928fd011 100644 --- a/exchanges/gateio/gateio_types.go +++ b/exchanges/gateio/gateio_types.go @@ -2012,12 +2012,13 @@ type WsEventResponse struct { // WsResponse represents generalized websocket push data from the server. type WsResponse struct { - ID int64 `json:"id"` - Time types.Time `json:"time"` - TimeMs types.Time `json:"time_ms"` - Channel string `json:"channel"` - Event string `json:"event"` - Result json.RawMessage `json:"result"` + ID int64 `json:"id"` + Time types.Time `json:"time"` + TimeMs types.Time `json:"time_ms"` + Channel string `json:"channel"` + Event string `json:"event"` + Result json.RawMessage `json:"result"` + RequestID string `json:"request_id"` } // WsTicker websocket ticker information. diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 8a4a5e2ab72..cb07d03706c 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -15,7 +15,6 @@ import ( "time" "github.com/Masterminds/sprig/v3" - "github.com/buger/jsonparser" "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" @@ -115,19 +114,18 @@ func (g *Gateio) generateWsSignature(secret, event, channel string, t int64) (st // WsHandleSpotData handles spot data func (g *Gateio) WsHandleSpotData(_ context.Context, respRaw []byte) error { - if requestID, err := jsonparser.GetString(respRaw, "request_id"); err == nil && requestID != "" { - if !g.Websocket.Match.IncomingWithData(requestID, respRaw) { - return fmt.Errorf("gateio_websocket.go error - unable to match requestID %v", requestID) - } - return nil - } - var push WsResponse - err := json.Unmarshal(respRaw, &push) - if err != nil { + if err := json.Unmarshal(respRaw, &push); err != nil { return err } + if push.RequestID != "" { + if !g.Websocket.Match.IncomingWithData(push.RequestID, respRaw) { + return fmt.Errorf("gateio_websocket.go error - unable to match requestID %v", push.RequestID) + } + return nil + } + if push.Event == subscribeEvent || push.Event == unsubscribeEvent { if !g.Websocket.Match.IncomingWithData(push.ID, respRaw) { return fmt.Errorf("couldn't match subscription message with ID: %d", push.ID) diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index 87bbc2bab27..fbb45618ebc 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -65,15 +65,13 @@ func (g *Gateio) WebsocketLogin(ctx context.Context, conn stream.Connection, cha } var inbound WebsocketAPIResponse - err = json.Unmarshal(resp, &inbound) - if err != nil { + if err := json.Unmarshal(resp, &inbound); err != nil { return nil, err } if inbound.Header.Status != "200" { var wsErr WebsocketErrors - err := json.Unmarshal(inbound.Data, &wsErr.Errors) - if err != nil { + if err := json.Unmarshal(inbound.Data, &wsErr.Errors); err != nil { return nil, err } return nil, fmt.Errorf("%s: %s", wsErr.Errors.Label, wsErr.Errors.Message) @@ -184,8 +182,7 @@ func (g *Gateio) WebsocketOrderCancelAllByPairSpot(ctx context.Context, pair cur } var resp []WebsocketOrderResponse - err := g.SendWebsocketRequest(ctx, "spot.order_cancel_cp", asset.Spot, params, &resp, 1) - return resp, err + return resp, g.SendWebsocketRequest(ctx, "spot.order_cancel_cp", asset.Spot, params, &resp, 1) } // WebsocketOrderAmendSpot amends an order via the websocket connection @@ -207,8 +204,7 @@ func (g *Gateio) WebsocketOrderAmendSpot(ctx context.Context, amend *WebsocketAm } var resp WebsocketOrderResponse - err := g.SendWebsocketRequest(ctx, "spot.order_amend", asset.Spot, amend, &resp, 1) - return &resp, err + return &resp, g.SendWebsocketRequest(ctx, "spot.order_amend", asset.Spot, amend, &resp, 1) } // WebsocketGetOrderStatusSpot gets the status of an order via the websocket connection @@ -231,8 +227,7 @@ func (g *Gateio) WebsocketGetOrderStatusSpot(ctx context.Context, orderID string } var resp WebsocketOrderResponse - err := g.SendWebsocketRequest(ctx, "spot.order_status", asset.Spot, params, &resp, 1) - return &resp, err + return &resp, g.SendWebsocketRequest(ctx, "spot.order_status", asset.Spot, params, &resp, 1) } // SendWebsocketRequest sends a websocket request to the exchange @@ -276,15 +271,13 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS // from that as the next response won't come anyway. endResponse := responses[len(responses)-1] - err = json.Unmarshal(endResponse, &inbound) - if err != nil { + if err := json.Unmarshal(endResponse, &inbound); err != nil { return err } if inbound.Header.Status != "200" { var wsErr WebsocketErrors - err = json.Unmarshal(inbound.Data, &wsErr) - if err != nil { + if err := json.Unmarshal(inbound.Data, &wsErr); err != nil { return err } return fmt.Errorf("%s: %s", wsErr.Errors.Label, wsErr.Errors.Message) From d950b2465775fe98a7d574dcdd304609420d7b0a Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 23 Oct 2024 11:38:16 +1100 Subject: [PATCH 100/138] linter: fix --- exchanges/stream/websocket_connection.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 076d7962c91..dc8a75cca03 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -315,7 +315,7 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep } start := time.Now() - if err := w.SendRawMessage(ctx, epl, websocket.TextMessage, outbound); err != nil { + if err = w.SendRawMessage(ctx, epl, websocket.TextMessage, outbound); err != nil { return nil, err } From 31339ed5155bc97ca9000e0d74818716fc617f56 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 23 Oct 2024 11:49:45 +1100 Subject: [PATCH 101/138] linter: fix again --- exchanges/stream/websocket_connection.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index dc8a75cca03..edc49d89b70 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -315,7 +315,8 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep } start := time.Now() - if err = w.SendRawMessage(ctx, epl, websocket.TextMessage, outbound); err != nil { + err = w.SendRawMessage(ctx, epl, websocket.TextMessage, outbound) + if err != nil { return nil, err } From 622c748d7e7158e49d8c6a55bbc82b9a34423d4c Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 23 Oct 2024 13:07:26 +1100 Subject: [PATCH 102/138] * change field name OutboundRequestSignature to WrapperDefinedConnectionSignature for agnostic inbound and outbound connections. * change method name GetOutboundConnection to GetConnection for agnostic inbound and outbound connections. * drop outbound field map for improved performance just using a range and field check (less complex as well) * change field name connections to connectionToWrapper for better clarity --- .../gateio/gateio_websocket_request_spot.go | 2 +- .../gateio_websocket_request_spot_test.go | 2 +- exchanges/gateio/gateio_wrapper.go | 82 +++++++++---------- exchanges/stream/stream_types.go | 9 +- exchanges/stream/websocket.go | 70 ++++++++-------- exchanges/stream/websocket_connection.go | 15 ++-- exchanges/stream/websocket_test.go | 31 ++++--- exchanges/stream/websocket_types.go | 10 +-- 8 files changed, 105 insertions(+), 116 deletions(-) diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index fbb45618ebc..e82231832fe 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -237,7 +237,7 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS return err } - conn, err := g.Websocket.GetOutboundConnection(connSignature) + conn, err := g.Websocket.GetConnection(connSignature) if err != nil { return err } diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go index 80bbb86c88d..af5d02e8723 100644 --- a/exchanges/gateio/gateio_websocket_request_spot_test.go +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -29,7 +29,7 @@ func TestWebsocketLogin(t *testing.T) { testexch.UpdatePairsOnce(t, g) g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes - demonstrationConn, err := g.Websocket.GetOutboundConnection(asset.Spot) + demonstrationConn, err := g.Websocket.GetConnection(asset.Spot) require.NoError(t, err) got, err := g.WebsocketLogin(context.Background(), demonstrationConn, "spot.login") diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 3cbd0d38c09..a84279110de 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -209,18 +209,18 @@ func (g *Gateio) Setup(exch *config.Exchange) error { } // Spot connection err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ - URL: gateioWebsocketEndpoint, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: g.WsHandleSpotData, - Subscriber: g.Subscribe, - Unsubscriber: g.Unsubscribe, - GenerateSubscriptions: g.generateSubscriptionsSpot, - Connector: g.WsConnectSpot, - Authenticate: g.AuthenticateSpot, - OutboundRequestSignature: asset.Spot, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + URL: gateioWebsocketEndpoint, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleSpotData, + Subscriber: g.Subscribe, + Unsubscriber: g.Unsubscribe, + GenerateSubscriptions: g.generateSubscriptionsSpot, + Connector: g.WsConnectSpot, + Authenticate: g.AuthenticateSpot, + WrapperDefinedConnectionSignature: asset.Spot, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -234,12 +234,12 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.Futures) }, - Subscriber: g.FuturesSubscribe, - Unsubscriber: g.FuturesUnsubscribe, - GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, - Connector: g.WsFuturesConnect, - OutboundRequestSignature: asset.USDTMarginedFutures, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, + Connector: g.WsFuturesConnect, + WrapperDefinedConnectionSignature: asset.USDTMarginedFutures, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -254,12 +254,12 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.Futures) }, - Subscriber: g.FuturesSubscribe, - Unsubscriber: g.FuturesUnsubscribe, - GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, - Connector: g.WsFuturesConnect, - OutboundRequestSignature: asset.CoinMarginedFutures, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, + Connector: g.WsFuturesConnect, + WrapperDefinedConnectionSignature: asset.CoinMarginedFutures, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -275,12 +275,12 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.DeliveryFutures) }, - Subscriber: g.DeliveryFuturesSubscribe, - Unsubscriber: g.DeliveryFuturesUnsubscribe, - GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions, - Connector: g.WsDeliveryFuturesConnect, - OutboundRequestSignature: asset.DeliveryFutures, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + Subscriber: g.DeliveryFuturesSubscribe, + Unsubscriber: g.DeliveryFuturesUnsubscribe, + GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions, + Connector: g.WsDeliveryFuturesConnect, + WrapperDefinedConnectionSignature: asset.DeliveryFutures, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -288,17 +288,17 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Futures connection - Options return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ - URL: optionsWebsocketURL, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: g.WsHandleOptionsData, - Subscriber: g.OptionsSubscribe, - Unsubscriber: g.OptionsUnsubscribe, - GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, - Connector: g.WsOptionsConnect, - OutboundRequestSignature: asset.Options, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + URL: optionsWebsocketURL, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleOptionsData, + Subscriber: g.OptionsSubscribe, + Unsubscriber: g.OptionsUnsubscribe, + GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, + Connector: g.WsOptionsConnect, + WrapperDefinedConnectionSignature: asset.Options, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) } diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 3860ceb7db7..803081fecae 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -84,11 +84,10 @@ type ConnectionSetup struct { // handle the authentication process and return an error if the // authentication fails. Authenticate func(ctx context.Context, conn Connection) error - // OutboundRequestSignature is any type that will match outbound - // requests to this specific connection. This could be an asset type - // `asset.Spot`, a string type denoting the individual URL, an - // authenticated or unauthenticated string or a mixture of these. - OutboundRequestSignature any + // WrapperDefinedConnectionSignature is any type that will match to a specfic connection. This could be an asset + // type `asset.Spot`, a string type denoting the individual URL, an authenticated or unauthenticated string or a + // mixture of these. + WrapperDefinedConnectionSignature any } // ConnectionWrapper contains the connection setup details to be used when diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 9e32369f715..023ca46b26b 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -30,7 +30,6 @@ var ( ErrNoMessageListener = errors.New("websocket listener not found for message") ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature") ErrRequestRouteNotFound = errors.New("request route not found") - ErrRequestRouteNotSet = errors.New("request route not set") ErrSignatureNotSet = errors.New("signature not set") ErrRequestPayloadNotSet = errors.New("request payload not set") ) @@ -69,6 +68,7 @@ var ( errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") errExchangeConfigEmpty = errors.New("exchange config is empty") errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") + errConnectionSignatureNotSet = errors.New("connection signature not set") ) var globalReporter Reporter @@ -90,13 +90,12 @@ func NewWebsocket() *Websocket { // after subscriptions are made but before the connectionMonitor has // started. This allows the error to be read and handled in the // connectionMonitor and start a connection cycle again. - ReadMessageErrors: make(chan error, 1), - Match: NewMatch(), - subscriptions: subscription.NewStore(), - features: &protocol.Features{}, - Orderbook: buffer.Orderbook{}, - connections: make(map[Connection]*ConnectionWrapper), - outbound: make(map[any]*ConnectionWrapper), + ReadMessageErrors: make(chan error, 1), + Match: NewMatch(), + subscriptions: subscription.NewStore(), + features: &protocol.Features{}, + Orderbook: buffer.Orderbook{}, + connectionToWrapper: make(map[Connection]*ConnectionWrapper), } } @@ -269,16 +268,14 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { // Below allows for multiple connections to the same URL with different outbound request signatures. This // allows for easier determination of inbound and outbound messages. e.g. Gateio cross_margin, margin on // a spot connection. - if w.connectionManager[x].Setup.URL == c.URL && c.OutboundRequestSignature == w.connectionManager[x].Setup.OutboundRequestSignature { + if w.connectionManager[x].Setup.URL == c.URL && c.WrapperDefinedConnectionSignature == w.connectionManager[x].Setup.WrapperDefinedConnectionSignature { return fmt.Errorf("%w: %w", errConnSetup, errConnectionWrapperDuplication) } } - w.connectionManager = append(w.connectionManager, &ConnectionWrapper{ Setup: c, Subscriptions: subscription.NewStore(), }) - w.outbound[c.OutboundRequestSignature] = w.connectionManager[len(w.connectionManager)-1] return nil } @@ -432,7 +429,7 @@ func (w *Websocket) connect() error { break } - w.connections[conn] = w.connectionManager[i] + w.connectionToWrapper[conn] = w.connectionManager[i] w.connectionManager[i].Connection = conn w.Wg.Add(1) @@ -473,7 +470,7 @@ func (w *Websocket) connect() error { } w.connectionManager[x].Subscriptions.Clear() } - clear(w.connections) + clear(w.connectionToWrapper) w.setState(disconnectedState) // Flip from connecting to disconnected. // Drain residual error in the single buffered channel, this mitigates @@ -561,7 +558,7 @@ func (w *Websocket) shutdown() error { } } // Clean map of old connections - clear(w.connections) + clear(w.connectionToWrapper) if w.Conn != nil { if err := w.Conn.Shutdown(); err != nil { @@ -652,7 +649,7 @@ func (w *Websocket) FlushChannels() error { } w.Wg.Add(1) go w.Reader(context.TODO(), conn, w.connectionManager[x].Setup.Handler) - w.connections[conn] = w.connectionManager[x] + w.connectionToWrapper[conn] = w.connectionManager[x] w.connectionManager[x].Connection = conn } @@ -671,7 +668,7 @@ func (w *Websocket) FlushChannels() error { // If there are no subscriptions to subscribe to, close the connection as it is no longer needed. if w.connectionManager[x].Subscriptions.Len() == 0 { - delete(w.connections, w.connectionManager[x].Connection) // Remove from lookup map + delete(w.connectionToWrapper, w.connectionManager[x].Connection) // Remove from lookup map if err := w.connectionManager[x].Connection.Shutdown(); err != nil { log.Warnf(log.WebsocketMgr, "%v websocket: failed to shutdown connection: %v", w.exchangeName, err) } @@ -832,7 +829,7 @@ func (w *Websocket) GetName() string { // and the new subscription list when pairs are disabled or enabled. func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) { var subscriptionStore **subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -848,7 +845,7 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L if len(channels) == 0 { return nil // No channels to unsubscribe from is not an error } - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { return w.unsubscribe(wrapper.Subscriptions, channels, func(channels subscription.List) error { return wrapper.Setup.Unsubscriber(context.TODO(), conn, channels) }) @@ -894,7 +891,7 @@ func (w *Websocket) SubscribeToChannels(conn Connection, subs subscription.List) return err } - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { return wrapper.Setup.Subscriber(context.TODO(), conn, subs) } @@ -915,7 +912,7 @@ func (w *Websocket) AddSubscriptions(conn Connection, subs ...*subscription.Subs return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer) } var subscriptionStore **subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -945,7 +942,7 @@ func (w *Websocket) AddSuccessfulSubscriptions(conn Connection, subs ...*subscri } var subscriptionStore **subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -974,7 +971,7 @@ func (w *Websocket) RemoveSubscriptions(conn Connection, subs ...*subscription.S } var subscriptionStore *subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { subscriptionStore = wrapper.Subscriptions } else { subscriptionStore = w.subscriptions @@ -1061,7 +1058,7 @@ func checkWebsocketURL(s string) error { // The subscription state is not considered when counting existing subscriptions func (w *Websocket) checkSubscriptions(conn Connection, subs subscription.List) error { var subscriptionStore *subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { subscriptionStore = wrapper.Subscriptions } else { subscriptionStore = w.subscriptions @@ -1261,15 +1258,15 @@ func signalReceived(ch chan struct{}) bool { } } -// GetOutboundConnection returns a connection specifically for outbound requests -// for multi connection management. -func (w *Websocket) GetOutboundConnection(connSignature any) (Connection, error) { +// GetConnection returns a connection by connection signature (defined in wrapper setup) for request and response +// handling in a multi connection context. +func (w *Websocket) GetConnection(connSignature any) (Connection, error) { if w == nil { return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, w) } - if connSignature == "" { - return nil, ErrRequestRouteNotSet + if connSignature == nil { + return nil, errConnectionSignatureNotSet } w.m.Lock() @@ -1283,14 +1280,15 @@ func (w *Websocket) GetOutboundConnection(connSignature any) (Connection, error) return nil, fmt.Errorf("%s: multi connection management not enabled %w please use exported Conn and AuthConn fields", w.exchangeName, errCannotObtainOutboundConnection) } - wrapper, ok := w.outbound[connSignature] - if !ok { - return nil, fmt.Errorf("%s: %w: %v", w.exchangeName, ErrRequestRouteNotFound, connSignature) - } - - if wrapper.Connection == nil { - return nil, fmt.Errorf("%s: %s %w: %v", w.exchangeName, wrapper.Setup.URL, ErrNotConnected, connSignature) + // Opted to range and not have a map, as connection level wrappers will be limited. + for _, wrapper := range w.connectionManager { + if wrapper.Setup.WrapperDefinedConnectionSignature == connSignature { + if wrapper.Connection == nil { + return nil, fmt.Errorf("%s: %s %w: %v", w.exchangeName, wrapper.Setup.URL, ErrNotConnected, connSignature) + } + return wrapper.Connection, nil + } } - return wrapper.Connection, nil + return nil, fmt.Errorf("%s: %w: %v", w.exchangeName, ErrRequestRouteNotFound, connSignature) } diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index edc49d89b70..ae8de076acf 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -329,13 +329,6 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start)) } - // Only check context verbosity. If the exchange is verbose, it will log the responses in the ReadMessage() call. - if request.IsVerbose(ctx, false) { - for i := range resps { - log.Debugf(log.WebsocketMgr, "%v %v: Received response [%d/%d]: %v", w.ExchangeName, removeURLQueryString(w.URL), i+1, len(resps), string(resps[i])) - } - } - return resps, err } @@ -362,6 +355,14 @@ func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature an return nil, ctx.Err() } } + + // Only check context verbosity. If the exchange is verbose, it will log the responses in the ReadMessage() call. + if request.IsVerbose(ctx, false) { + for i := range resps { + log.Debugf(log.WebsocketMgr, "%v %v: Received response [%d/%d]: %v", w.ExchangeName, removeURLQueryString(w.URL), i+1, len(resps), string(resps[i])) + } + } + return resps, nil } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index e7421690bb7..a6bc1646ab3 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -467,7 +467,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { require.NoError(t, multi.SetupNewConnection(amazingCandidate)) amazingConn := multi.getConnectionFromSetup(amazingCandidate) - multi.connections = map[Connection]*ConnectionWrapper{ + multi.connectionToWrapper = map[Connection]*ConnectionWrapper{ amazingConn: multi.connectionManager[0], } @@ -979,7 +979,7 @@ func TestGetChannelDifference(t *testing.T) { require.Equal(t, 1, len(subs)) require.Empty(t, unsubs, "Should get no unsubs") - w.connections = map[Connection]*ConnectionWrapper{ + w.connectionToWrapper = map[Connection]*ConnectionWrapper{ sweetConn: {Setup: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}, } @@ -992,7 +992,7 @@ func TestGetChannelDifference(t *testing.T) { require.Equal(t, 1, len(subs)) require.Empty(t, unsubs, "Should get no unsubs") - err := w.connections[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel}) + err := w.connectionToWrapper[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel}) require.NoError(t, err) subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}}) @@ -1489,42 +1489,39 @@ func TestMonitorTraffic(t *testing.T) { require.False(t, innerShell()) } -func TestGetOutboundConnection(t *testing.T) { +func TestGetConnection(t *testing.T) { t.Parallel() var ws *Websocket - _, err := ws.GetOutboundConnection("") + _, err := ws.GetConnection(nil) require.ErrorIs(t, err, common.ErrNilPointer) ws = &Websocket{} - _, err = ws.GetOutboundConnection("") - require.ErrorIs(t, err, ErrRequestRouteNotSet) - _, err = ws.GetOutboundConnection("testURL") + _, err = ws.GetConnection(nil) + require.ErrorIs(t, err, errConnectionSignatureNotSet) + + _, err = ws.GetConnection("testURL") require.ErrorIs(t, err, ErrNotConnected) ws.setState(connectedState) - _, err = ws.GetOutboundConnection("testURL") + _, err = ws.GetConnection("testURL") require.ErrorIs(t, err, errCannotObtainOutboundConnection) ws.useMultiConnectionManagement = true - _, err = ws.GetOutboundConnection("testURL") + _, err = ws.GetConnection("testURL") require.ErrorIs(t, err, ErrRequestRouteNotFound) ws.connectionManager = []*ConnectionWrapper{{ - Setup: &ConnectionSetup{URL: "testURL"}, + Setup: &ConnectionSetup{WrapperDefinedConnectionSignature: "testURL", URL: "testURL"}, }} - ws.outbound = map[any]*ConnectionWrapper{ - "testURL": ws.connectionManager[0], - } - - _, err = ws.GetOutboundConnection("testURL") + _, err = ws.GetConnection("testURL") require.ErrorIs(t, err, ErrNotConnected) expected := &WebsocketConnection{} ws.connectionManager[0].Connection = expected - conn, err := ws.GetOutboundConnection("testURL") + conn, err := ws.GetConnection("testURL") require.NoError(t, err) assert.Same(t, expected, conn) } diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index a25116d6529..5152ff2c342 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -55,14 +55,8 @@ type Websocket struct { // for exchanges that differentiate between trading pairs by using different connection endpoints or protocols for various asset classes. // If an exchange does not require such differentiation, all connections may be managed under a single ConnectionWrapper. connectionManager []*ConnectionWrapper - // connections holds a look up table for all connections to their corresponding ConnectionWrapper and subscription holder - connections map[Connection]*ConnectionWrapper - // outbound is map holding wrapper specific signatures to an active - // connection for outbound messaging. Wrapper specific connections - // might be asset specific e.g. spot, margin, futures or - // authenticated/unauthenticated or a mix of both. This map is used - // to send messages to the correct connection. - outbound map[any]*ConnectionWrapper + // connectionToWrapper holds a look up table for all connections to their corresponding ConnectionWrapper and subscription holder + connectionToWrapper map[Connection]*ConnectionWrapper subscriptions *subscription.Store From b71d3c00c195211afa4901e65dd392ae460909b4 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 23 Oct 2024 13:11:36 +1100 Subject: [PATCH 103/138] spells and magic and wands --- exchanges/stream/stream_types.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 803081fecae..45c7d0d6534 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -84,7 +84,7 @@ type ConnectionSetup struct { // handle the authentication process and return an error if the // authentication fails. Authenticate func(ctx context.Context, conn Connection) error - // WrapperDefinedConnectionSignature is any type that will match to a specfic connection. This could be an asset + // WrapperDefinedConnectionSignature is any type that will match to a specific connection. This could be an asset // type `asset.Spot`, a string type denoting the individual URL, an authenticated or unauthenticated string or a // mixture of these. WrapperDefinedConnectionSignature any From 2bd207f9b771ed723a3ee9d179372a89bd21c95b Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 25 Oct 2024 12:06:30 +1100 Subject: [PATCH 104/138] glorious: nits --- .../gateio/gateio_websocket_request_spot.go | 22 +++---------------- .../gateio_websocket_request_spot_test.go | 8 +++---- .../gateio/gateio_websocket_request_types.go | 11 ++++++++-- exchanges/stream/stream_types.go | 2 +- exchanges/stream/websocket_connection.go | 8 +++---- 5 files changed, 21 insertions(+), 30 deletions(-) diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index e82231832fe..d58eecdc788 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -129,15 +129,7 @@ func (g *Gateio) WebsocketOrderCancelSpot(ctx context.Context, orderID string, p return nil, currency.ErrCurrencyPairEmpty } - params := &struct { - OrderID string `json:"order_id"` // This requires order_id tag - Pair string `json:"pair"` - Account string `json:"account,omitempty"` - }{ - OrderID: orderID, - Pair: pair.String(), - Account: account, - } + params := &WebsocketOrderRequest{OrderID: orderID, Pair: pair.String(), Account: account} var resp WebsocketOrderResponse err := g.SendWebsocketRequest(ctx, "spot.order_cancel", asset.Spot, params, &resp, 1) @@ -145,7 +137,7 @@ func (g *Gateio) WebsocketOrderCancelSpot(ctx context.Context, orderID string, p } // WebsocketOrderCancelAllByIDsSpot cancels multiple orders via the websocket -func (g *Gateio) WebsocketOrderCancelAllByIDsSpot(ctx context.Context, o []WebsocketOrderCancelRequest) ([]WebsocketCancellAllResponse, error) { +func (g *Gateio) WebsocketOrderCancelAllByIDsSpot(ctx context.Context, o []WebsocketOrderBatchRequest) ([]WebsocketCancellAllResponse, error) { if len(o) == 0 { return nil, errNoOrdersToCancel } @@ -216,15 +208,7 @@ func (g *Gateio) WebsocketGetOrderStatusSpot(ctx context.Context, orderID string return nil, currency.ErrCurrencyPairEmpty } - params := &struct { - OrderID string `json:"order_id"` // This requires order_id tag - Pair string `json:"pair"` - Account string `json:"account,omitempty"` - }{ - OrderID: orderID, - Pair: pair.String(), - Account: account, - } + params := &WebsocketOrderRequest{OrderID: orderID, Pair: pair.String(), Account: account} var resp WebsocketOrderResponse return &resp, g.SendWebsocketRequest(ctx, "spot.order_status", asset.Spot, params, &resp, 1) diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go index af5d02e8723..3ec059320b3 100644 --- a/exchanges/gateio/gateio_websocket_request_spot_test.go +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -93,11 +93,11 @@ func TestWebsocketOrderCancelSpot(t *testing.T) { func TestWebsocketOrderCancelAllByIDsSpot(t *testing.T) { t.Parallel() - out := WebsocketOrderCancelRequest{} - _, err := g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderCancelRequest{out}) + out := WebsocketOrderBatchRequest{} + _, err := g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{out}) require.ErrorIs(t, err, order.ErrOrderIDNotSet) out.OrderID = "1337" - _, err = g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderCancelRequest{out}) + _, err = g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{out}) require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) out.Pair, err = currency.NewPairFromString("BTC_USDT") @@ -109,7 +109,7 @@ func TestWebsocketOrderCancelAllByIDsSpot(t *testing.T) { g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes out.OrderID = "644913101755" - got, err := g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderCancelRequest{out}) + got, err := g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{out}) require.NoError(t, err) require.NotEmpty(t, got) } diff --git a/exchanges/gateio/gateio_websocket_request_types.go b/exchanges/gateio/gateio_websocket_request_types.go index 0304bd13f18..b16aaaa0d48 100644 --- a/exchanges/gateio/gateio_websocket_request_types.go +++ b/exchanges/gateio/gateio_websocket_request_types.go @@ -112,13 +112,20 @@ type WebsocketOrderResponse struct { STPAct string `json:"stp_act"` } -// WebsocketOrderCancelRequest defines a websocket order cancel request -type WebsocketOrderCancelRequest struct { +// WebsocketOrderBatchRequest defines a websocket order batch request +type WebsocketOrderBatchRequest struct { OrderID string `json:"id"` // This require id tag not order_id Pair currency.Pair `json:"currency_pair"` Account string `json:"account,omitempty"` } +// WebsocketOrderRequest defines a websocket order request +type WebsocketOrderRequest struct { + OrderID string `json:"order_id"` // This requires order_id tag + Pair string `json:"pair"` + Account string `json:"account,omitempty"` +} + // WebsocketCancellAllResponse defines a websocket order cancel response type WebsocketCancellAllResponse struct { Pair currency.Pair `json:"currency_pair"` diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 45c7d0d6534..630982f195c 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -26,7 +26,7 @@ type Connection interface { // SendMessageReturnResponse will send a WS message to the connection and wait for response SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature any, request any) ([]byte, error) // SendMessageReturnResponses will send a WS message to the connection and wait for N responses - SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature any, request any, expected int, isFinalMessage ...Inspector) ([][]byte, error) + SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature any, request any, expected int, messageInspector ...Inspector) ([][]byte, error) // SendRawMessage sends a message over the connection without JSON encoding it SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error // SendJSONMessage sends a JSON encoded message over the connection diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index ae8de076acf..5300bdb6c13 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -303,7 +303,7 @@ func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, epl // SendMessageReturnResponses will send a WS message to the connection and wait for N responses // An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked -func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int, isFinalMessage ...Inspector) ([][]byte, error) { +func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int, messageInspector ...Inspector) ([][]byte, error) { outbound, err := json.Marshal(payload) if err != nil { return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err) @@ -320,7 +320,7 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep return nil, err } - resps, err := w.waitForResponses(ctx, signature, ch, expected, isFinalMessage...) + resps, err := w.waitForResponses(ctx, signature, ch, expected, messageInspector...) if err != nil { return nil, err } @@ -333,7 +333,7 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep } // waitForResponses waits for N responses from a channel -func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature any, ch <-chan []byte, expected int, isFinalMessage ...Inspector) ([][]byte, error) { +func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature any, ch <-chan []byte, expected int, messageInspector ...Inspector) ([][]byte, error) { timeout := time.NewTimer(w.ResponseMaxLimit * time.Duration(expected)) defer timeout.Stop() @@ -343,7 +343,7 @@ func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature an case resp := <-ch: resps = append(resps, resp) // Checks recently received message to determine if this is in fact the final message in a sequence of messages. - if len(isFinalMessage) == 1 && isFinalMessage[0](resp) { + if len(messageInspector) == 1 && messageInspector[0](resp) { w.Match.RemoveSignature(signature) return resps, nil } From 511c78b43522acd6fe8b0b077896070439f2fc0f Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 25 Oct 2024 12:23:23 +1100 Subject: [PATCH 105/138] comparable check for signature --- exchanges/stream/websocket.go | 7 +++++++ exchanges/stream/websocket_test.go | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 023ca46b26b..a7bec78090a 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/url" + "reflect" "slices" "sync" "time" @@ -211,6 +212,8 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return nil } +var errWrapperDefinedConnectionSignatureNotComparable = errors.New("wrapper defined connection signature is not comparable") + // SetupNewConnection sets up an auth or unauth streaming connection func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { if w == nil { @@ -264,6 +267,10 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset) } + if c.WrapperDefinedConnectionSignature != nil && !reflect.TypeOf(c.WrapperDefinedConnectionSignature).Comparable() { + return errWrapperDefinedConnectionSignatureNotComparable + } + for x := range w.connectionManager { // Below allows for multiple connections to the same URL with different outbound request signatures. This // allows for easier determination of inbound and outbound messages. e.g. Gateio cross_margin, margin on diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index a6bc1646ab3..290899b8d9d 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1233,6 +1233,11 @@ func TestSetupNewConnection(t *testing.T) { require.ErrorIs(t, err, errWebsocketDataHandlerUnset) connSetup.Handler = func(context.Context, []byte) error { return nil } + connSetup.WrapperDefinedConnectionSignature = []string{"slices are super naughty and not comparable"} + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errWrapperDefinedConnectionSignatureNotComparable) + + connSetup.WrapperDefinedConnectionSignature = "comparable string signature" err = multi.SetupNewConnection(connSetup) require.NoError(t, err) From 0a196b883e3447a5ef61de65434da88761b9f59f Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 25 Oct 2024 12:26:37 +1100 Subject: [PATCH 106/138] mv err var --- exchanges/stream/websocket.go | 69 +++++++++++++++++------------------ 1 file changed, 34 insertions(+), 35 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index a7bec78090a..e844436264f 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -37,39 +37,40 @@ var ( // Private websocket errors var ( - errExchangeConfigIsNil = errors.New("exchange config is nil") - errWebsocketIsNil = errors.New("websocket is nil") - errWebsocketSetupIsNil = errors.New("websocket setup is nil") - errWebsocketAlreadyInitialised = errors.New("websocket already initialised") - errWebsocketAlreadyEnabled = errors.New("websocket already enabled") - errWebsocketFeaturesIsUnset = errors.New("websocket features is unset") - errConfigFeaturesIsNil = errors.New("exchange config features is nil") - errDefaultURLIsEmpty = errors.New("default url is empty") - errRunningURLIsEmpty = errors.New("running url cannot be empty") - errInvalidWebsocketURL = errors.New("invalid websocket url") - errExchangeConfigNameEmpty = errors.New("exchange config name empty") - errInvalidTrafficTimeout = errors.New("invalid traffic timeout") - errTrafficAlertNil = errors.New("traffic alert is nil") - errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") - errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") - errWebsocketConnectorUnset = errors.New("websocket connector function not set") - errWebsocketDataHandlerUnset = errors.New("websocket data handler not set") - errReadMessageErrorsNil = errors.New("read message errors is nil") - errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") - errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") - errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0") - errSameProxyAddress = errors.New("cannot set proxy address to the same address") - errNoConnectFunc = errors.New("websocket connect func not set") - errAlreadyConnected = errors.New("websocket already connected") - errCannotShutdown = errors.New("websocket cannot shutdown") - errAlreadyReconnecting = errors.New("websocket in the process of reconnection") - errConnSetup = errors.New("error in connection setup") - errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") - errConnectionWrapperDuplication = errors.New("connection wrapper duplication") - errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") - errExchangeConfigEmpty = errors.New("exchange config is empty") - errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") - errConnectionSignatureNotSet = errors.New("connection signature not set") + errExchangeConfigIsNil = errors.New("exchange config is nil") + errWebsocketIsNil = errors.New("websocket is nil") + errWebsocketSetupIsNil = errors.New("websocket setup is nil") + errWebsocketAlreadyInitialised = errors.New("websocket already initialised") + errWebsocketAlreadyEnabled = errors.New("websocket already enabled") + errWebsocketFeaturesIsUnset = errors.New("websocket features is unset") + errConfigFeaturesIsNil = errors.New("exchange config features is nil") + errDefaultURLIsEmpty = errors.New("default url is empty") + errRunningURLIsEmpty = errors.New("running url cannot be empty") + errInvalidWebsocketURL = errors.New("invalid websocket url") + errExchangeConfigNameEmpty = errors.New("exchange config name empty") + errInvalidTrafficTimeout = errors.New("invalid traffic timeout") + errTrafficAlertNil = errors.New("traffic alert is nil") + errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") + errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") + errWebsocketConnectorUnset = errors.New("websocket connector function not set") + errWebsocketDataHandlerUnset = errors.New("websocket data handler not set") + errReadMessageErrorsNil = errors.New("read message errors is nil") + errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") + errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") + errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0") + errSameProxyAddress = errors.New("cannot set proxy address to the same address") + errNoConnectFunc = errors.New("websocket connect func not set") + errAlreadyConnected = errors.New("websocket already connected") + errCannotShutdown = errors.New("websocket cannot shutdown") + errAlreadyReconnecting = errors.New("websocket in the process of reconnection") + errConnSetup = errors.New("error in connection setup") + errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") + errConnectionWrapperDuplication = errors.New("connection wrapper duplication") + errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") + errExchangeConfigEmpty = errors.New("exchange config is empty") + errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") + errConnectionSignatureNotSet = errors.New("connection signature not set") + errWrapperDefinedConnectionSignatureNotComparable = errors.New("wrapper defined connection signature is not comparable") ) var globalReporter Reporter @@ -212,8 +213,6 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return nil } -var errWrapperDefinedConnectionSignatureNotComparable = errors.New("wrapper defined connection signature is not comparable") - // SetupNewConnection sets up an auth or unauth streaming connection func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { if w == nil { From d3343a7dbde3a935286a248b1bf19c8ed6524b99 Mon Sep 17 00:00:00 2001 From: shazbert Date: Tue, 5 Nov 2024 16:31:37 +1100 Subject: [PATCH 107/138] glorious: nits and stuff --- exchanges/gateio/gateio_websocket.go | 6 +++--- exchanges/gateio/gateio_websocket_request_spot.go | 4 ++-- exchanges/gateio/gateio_websocket_request_spot_test.go | 6 +++--- exchanges/gateio/gateio_wrapper.go | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index cb07d03706c..7a25cf7c997 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -97,9 +97,9 @@ func (g *Gateio) WsConnectSpot(ctx context.Context, conn stream.Connection) erro return nil } -// AuthenticateSpot sends an authentication message to the websocket connection -func (g *Gateio) AuthenticateSpot(ctx context.Context, conn stream.Connection) error { - _, err := g.WebsocketLogin(ctx, conn, "spot.login") +// authenticateSpot sends an authentication message to the websocket connection +func (g *Gateio) authenticateSpot(ctx context.Context, conn stream.Connection) error { + _, err := g.websocketLogin(ctx, conn, "spot.login") return err } diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index d58eecdc788..80484180c80 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -27,8 +27,8 @@ var ( errChannelEmpty = errors.New("channel cannot be empty") ) -// WebsocketLogin authenticates the websocket connection -func (g *Gateio) WebsocketLogin(ctx context.Context, conn stream.Connection, channel string) (*WebsocketLoginResponse, error) { +// websocketLogin authenticates the websocket connection +func (g *Gateio) websocketLogin(ctx context.Context, conn stream.Connection, channel string) (*WebsocketLoginResponse, error) { if conn == nil { return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, conn) } diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go index 3ec059320b3..2243cb16055 100644 --- a/exchanges/gateio/gateio_websocket_request_spot_test.go +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -18,10 +18,10 @@ import ( func TestWebsocketLogin(t *testing.T) { t.Parallel() - _, err := g.WebsocketLogin(context.Background(), nil, "") + _, err := g.websocketLogin(context.Background(), nil, "") require.ErrorIs(t, err, common.ErrNilPointer) - _, err = g.WebsocketLogin(context.Background(), &stream.WebsocketConnection{}, "") + _, err = g.websocketLogin(context.Background(), &stream.WebsocketConnection{}, "") require.ErrorIs(t, err, errChannelEmpty) sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) @@ -32,7 +32,7 @@ func TestWebsocketLogin(t *testing.T) { demonstrationConn, err := g.Websocket.GetConnection(asset.Spot) require.NoError(t, err) - got, err := g.WebsocketLogin(context.Background(), demonstrationConn, "spot.login") + got, err := g.websocketLogin(context.Background(), demonstrationConn, "spot.login") require.NoError(t, err) require.NotEmpty(t, got) } diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index a84279110de..de960181936 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -218,7 +218,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Unsubscriber: g.Unsubscribe, GenerateSubscriptions: g.generateSubscriptionsSpot, Connector: g.WsConnectSpot, - Authenticate: g.AuthenticateSpot, + Authenticate: g.authenticateSpot, WrapperDefinedConnectionSignature: asset.Spot, BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) From 2872e055d827d7fc15a03ad1115f0ed6b2d53d9e Mon Sep 17 00:00:00 2001 From: shazbert Date: Fri, 8 Nov 2024 10:59:37 +1100 Subject: [PATCH 108/138] attempt to fix race --- exchanges/gateio/gateio_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index 477e0d1616f..5cd0c74295d 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -3011,11 +3011,11 @@ func TestGetSettlementFromCurrency(t *testing.T) { for _, assetType := range []asset.Item{asset.Futures, asset.DeliveryFutures, asset.Options} { availPairs, err := g.GetAvailablePairs(assetType) require.NoErrorf(t, err, "GetAvailablePairs for asset %s must not error", assetType) - for x := range availPairs { - t.Run(strconv.Itoa(x), func(t *testing.T) { + for x, pair := range availPairs { + t.Run(strconv.Itoa(x)+":"+assetType.String(), func(t *testing.T) { t.Parallel() - _, err = getSettlementFromCurrency(availPairs[x]) - assert.NoErrorf(t, err, "getSettlementFromCurrency should not error for pair %s and asset %s", availPairs[x], assetType) + _, err := getSettlementFromCurrency(pair) + assert.NoErrorf(t, err, "getSettlementFromCurrency should not error for pair %s and asset %s", pair, assetType) }) } } From 2b79a9c2a024ab2ee61a211134c282e6ad3af5fb Mon Sep 17 00:00:00 2001 From: shazbert Date: Fri, 15 Nov 2024 11:20:53 +1100 Subject: [PATCH 109/138] glorious: nits --- exchanges/gateio/gateio_websocket.go | 4 ++-- .../gateio/gateio_websocket_request_spot.go | 16 +++++----------- exchanges/stream/stream_types.go | 4 +++- exchanges/stream/websocket.go | 9 ++++----- exchanges/stream/websocket_connection.go | 5 +++++ exchanges/stream/websocket_test.go | 13 +++++++++---- 6 files changed, 28 insertions(+), 23 deletions(-) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 7a25cf7c997..ce733e3767c 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -121,14 +121,14 @@ func (g *Gateio) WsHandleSpotData(_ context.Context, respRaw []byte) error { if push.RequestID != "" { if !g.Websocket.Match.IncomingWithData(push.RequestID, respRaw) { - return fmt.Errorf("gateio_websocket.go error - unable to match requestID %v", push.RequestID) + return fmt.Errorf("%w for requestID %v", stream.ErrNoMessageListener, push.RequestID) } return nil } if push.Event == subscribeEvent || push.Event == unsubscribeEvent { if !g.Websocket.Match.IncomingWithData(push.ID, respRaw) { - return fmt.Errorf("couldn't match subscription message with ID: %d", push.ID) + return fmt.Errorf("%w couldn't match subscription message with ID: %d", stream.ErrNoMessageListener, push.ID) } return nil } diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index 80484180c80..33c19a5b7c9 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -83,7 +83,6 @@ func (g *Gateio) websocketLogin(ctx context.Context, conn stream.Connection, cha // WebsocketOrderPlaceSpot places an order via the websocket connection. You can // send multiple orders in a single request. But only for one asset route. -// So this can only batch spot orders or futures orders, not both. func (g *Gateio) WebsocketOrderPlaceSpot(ctx context.Context, batch []WebsocketOrder) ([]WebsocketOrderResponse, error) { if len(batch) == 0 { return nil, errBatchSliceEmpty @@ -111,13 +110,10 @@ func (g *Gateio) WebsocketOrderPlaceSpot(ctx context.Context, batch []WebsocketO if len(batch) == 1 { var singleResponse WebsocketOrderResponse - err := g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, batch[0], &singleResponse, 2) - return []WebsocketOrderResponse{singleResponse}, err + return []WebsocketOrderResponse{singleResponse}, g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, batch[0], &singleResponse, 2) } - var resp []WebsocketOrderResponse - err := g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, batch, &resp, 2) - return resp, err + return resp, g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, batch, &resp, 2) } // WebsocketOrderCancelSpot cancels an order via the websocket connection @@ -132,8 +128,7 @@ func (g *Gateio) WebsocketOrderCancelSpot(ctx context.Context, orderID string, p params := &WebsocketOrderRequest{OrderID: orderID, Pair: pair.String(), Account: account} var resp WebsocketOrderResponse - err := g.SendWebsocketRequest(ctx, "spot.order_cancel", asset.Spot, params, &resp, 1) - return &resp, err + return &resp, g.SendWebsocketRequest(ctx, "spot.order_cancel", asset.Spot, params, &resp, 1) } // WebsocketOrderCancelAllByIDsSpot cancels multiple orders via the websocket @@ -152,8 +147,7 @@ func (g *Gateio) WebsocketOrderCancelAllByIDsSpot(ctx context.Context, o []Webso } var resp []WebsocketCancellAllResponse - err := g.SendWebsocketRequest(ctx, "spot.order_cancel_ids", asset.Spot, o, &resp, 2) - return resp, err + return resp, g.SendWebsocketRequest(ctx, "spot.order_cancel_ids", asset.Spot, o, &resp, 2) } // WebsocketOrderCancelAllByPairSpot cancels all orders for a specific pair @@ -246,7 +240,7 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS } if len(responses) == 0 { - return errors.New("no responses received") + return common.ErrNoResponse } var inbound WebsocketAPIResponse diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 630982f195c..e6d0622eb05 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -37,7 +37,9 @@ type Connection interface { Shutdown() error } -// Inspector is a hook that allows for custom message inspection +// Inspector is used to verify messages via SendMessageReturnResponse +// Only one can used +// It inspects the []bytes websocket message and returns true if it is the appropriate message to action type Inspector func([]byte) bool // Response defines generalised data from the stream connection diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index e844436264f..1976c8b8f3f 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -1278,15 +1278,14 @@ func (w *Websocket) GetConnection(connSignature any) (Connection, error) { w.m.Lock() defer w.m.Unlock() - if !w.IsConnected() { - return nil, ErrNotConnected - } - if !w.useMultiConnectionManagement { return nil, fmt.Errorf("%s: multi connection management not enabled %w please use exported Conn and AuthConn fields", w.exchangeName, errCannotObtainOutboundConnection) } - // Opted to range and not have a map, as connection level wrappers will be limited. + if !w.IsConnected() { + return nil, ErrNotConnected + } + for _, wrapper := range w.connectionManager { if wrapper.Setup.WrapperDefinedConnectionSignature == connSignature { if wrapper.Connection == nil { diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 5300bdb6c13..9ccf43c0f91 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -28,6 +28,7 @@ var ( errConnectionFault = errors.New("connection fault") errWebsocketIsDisconnected = errors.New("websocket connection is disconnected") errRateLimitNotFound = errors.New("rate limit definition not found") + errOnlyOneMessageInspector = errors.New("only one message inspector can be used") ) // Dial sets proxy urls and then connects to the websocket @@ -334,6 +335,10 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep // waitForResponses waits for N responses from a channel func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature any, ch <-chan []byte, expected int, messageInspector ...Inspector) ([][]byte, error) { + if len(messageInspector) > 1 { + return nil, errOnlyOneMessageInspector + } + timeout := time.NewTimer(w.ResponseMaxLimit * time.Duration(expected)) defer timeout.Stop() diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 290899b8d9d..cf935c3b219 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -224,7 +224,6 @@ func TestConnectionMessageErrors(t *testing.T) { ws.useMultiConnectionManagement = true ws.SetCanUseAuthenticatedEndpoints(true) - ws.verbose = true // NOTE: Intentional mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() @@ -765,6 +764,10 @@ func TestSendMessageReturnResponse(t *testing.T) { wc.ResponseMaxLimit = 1 _, err = wc.SendMessageReturnResponse(context.Background(), request.Unset, "123", req) assert.ErrorIs(t, err, ErrSignatureTimeout, "SendMessageReturnResponse should error when request ID not found") + + inspector := func(b []byte) bool { return false } + _, err = wc.SendMessageReturnResponses(context.Background(), request.Unset, "123", req, 1, inspector, inspector) + assert.ErrorIs(t, err, errOnlyOneMessageInspector) } type reporter struct { @@ -1505,14 +1508,16 @@ func TestGetConnection(t *testing.T) { _, err = ws.GetConnection(nil) require.ErrorIs(t, err, errConnectionSignatureNotSet) + _, err = ws.GetConnection("testURL") + require.ErrorIs(t, err, errCannotObtainOutboundConnection) + + ws.useMultiConnectionManagement = true + _, err = ws.GetConnection("testURL") require.ErrorIs(t, err, ErrNotConnected) ws.setState(connectedState) - _, err = ws.GetConnection("testURL") - require.ErrorIs(t, err, errCannotObtainOutboundConnection) - ws.useMultiConnectionManagement = true _, err = ws.GetConnection("testURL") require.ErrorIs(t, err, ErrRequestRouteNotFound) From 1696d08e6b119cf003100a06e208e55326bda53a Mon Sep 17 00:00:00 2001 From: shazbert Date: Thu, 21 Nov 2024 08:44:58 +1100 Subject: [PATCH 110/138] gk: nits; engine log cleanup --- engine/engine.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/engine/engine.go b/engine/engine.go index 27e11219a25..878b8a9fedc 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -791,10 +791,11 @@ func (bot *Engine) LoadExchange(name string) error { localWG.Wait() if !bot.Settings.EnableExchangeHTTPRateLimiter { - gctlog.Warnf(gctlog.ExchangeSys, "Loaded exchange %s rate limiting has been turned off.\n", exch.GetName()) err = exch.DisableRateLimiter() if err != nil { - gctlog.Errorf(gctlog.ExchangeSys, "Loaded exchange %s rate limiting cannot be turned off: %s.\n", exch.GetName(), err) + gctlog.Errorf(gctlog.ExchangeSys, "error disabling rate limiter for %s: %v", exch.GetName(), err) + } else { + gctlog.Warnf(gctlog.ExchangeSys, "Loaded exchange %s rate limiting has been turned off\n", exch.GetName()) } } @@ -817,7 +818,7 @@ func (bot *Engine) LoadExchange(name string) error { if b.API.AuthenticatedSupport || b.API.AuthenticatedWebsocketSupport { err = exch.ValidateAPICredentials(context.TODO(), asset.Spot) if err != nil { - gctlog.Warnf(gctlog.ExchangeSys, "%s: Cannot validate credentials, authenticated support has been disabled, Error: %s", b.Name, err) + gctlog.Warnf(gctlog.ExchangeSys, "%s: Error validating credentials: %v", b.Name, err) b.API.AuthenticatedSupport = false b.API.AuthenticatedWebsocketSupport = false exchCfg.API.AuthenticatedSupport = false From cb2a2351f1ea244227848aae3869d72765f2f15c Mon Sep 17 00:00:00 2001 From: shazbert Date: Thu, 21 Nov 2024 08:47:55 +1100 Subject: [PATCH 111/138] gk: nits; OCD --- exchanges/gateio/gateio_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index 5cd0c74295d..c6e95e956e3 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -3011,8 +3011,8 @@ func TestGetSettlementFromCurrency(t *testing.T) { for _, assetType := range []asset.Item{asset.Futures, asset.DeliveryFutures, asset.Options} { availPairs, err := g.GetAvailablePairs(assetType) require.NoErrorf(t, err, "GetAvailablePairs for asset %s must not error", assetType) - for x, pair := range availPairs { - t.Run(strconv.Itoa(x)+":"+assetType.String(), func(t *testing.T) { + for i, pair := range availPairs { + t.Run(strconv.Itoa(i)+":"+assetType.String(), func(t *testing.T) { t.Parallel() _, err := getSettlementFromCurrency(pair) assert.NoErrorf(t, err, "getSettlementFromCurrency should not error for pair %s and asset %s", pair, assetType) From 78e974a32ce642e82763142781487e54abc6f86e Mon Sep 17 00:00:00 2001 From: shazbert Date: Thu, 21 Nov 2024 09:10:23 +1100 Subject: [PATCH 112/138] gk: nits; move function change file names --- exchanges/gateio/gateio_websocket.go | 54 +++++++++++++++++ ...o => gateio_websocket_delivery_futures.go} | 0 ...futures.go => gateio_websocket_futures.go} | 0 ...s_option.go => gateio_websocket_option.go} | 0 .../gateio/gateio_websocket_request_spot.go | 58 ------------------- 5 files changed, 54 insertions(+), 58 deletions(-) rename exchanges/gateio/{gateio_ws_delivery_futures.go => gateio_websocket_delivery_futures.go} (100%) rename exchanges/gateio/{gateio_ws_futures.go => gateio_websocket_futures.go} (100%) rename exchanges/gateio/{gateio_ws_option.go => gateio_websocket_option.go} (100%) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index ce733e3767c..4044fdecd50 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -103,6 +103,60 @@ func (g *Gateio) authenticateSpot(ctx context.Context, conn stream.Connection) e return err } +// websocketLogin authenticates the websocket connection +func (g *Gateio) websocketLogin(ctx context.Context, conn stream.Connection, channel string) (*WebsocketLoginResponse, error) { + if conn == nil { + return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, conn) + } + + if channel == "" { + return nil, errChannelEmpty + } + + creds, err := g.GetCredentials(ctx) + if err != nil { + return nil, err + } + + tn := time.Now().Unix() + msg := "api\n" + channel + "\n" + "\n" + strconv.FormatInt(tn, 10) + mac := hmac.New(sha512.New, []byte(creds.Secret)) + if _, err = mac.Write([]byte(msg)); err != nil { + return nil, err + } + signature := hex.EncodeToString(mac.Sum(nil)) + + payload := WebsocketPayload{ + RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), + APIKey: creds.Key, + Signature: signature, + Timestamp: strconv.FormatInt(tn, 10), + } + + req := WebsocketRequest{Time: tn, Channel: channel, Event: "api", Payload: payload} + + resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, req.Payload.RequestID, req) + if err != nil { + return nil, err + } + + var inbound WebsocketAPIResponse + if err := json.Unmarshal(resp, &inbound); err != nil { + return nil, err + } + + if inbound.Header.Status != "200" { + var wsErr WebsocketErrors + if err := json.Unmarshal(inbound.Data, &wsErr.Errors); err != nil { + return nil, err + } + return nil, fmt.Errorf("%s: %s", wsErr.Errors.Label, wsErr.Errors.Message) + } + + var result WebsocketLoginResponse + return &result, json.Unmarshal(inbound.Data, &result) +} + func (g *Gateio) generateWsSignature(secret, event, channel string, t int64) (string, error) { msg := "channel=" + channel + "&event=" + event + "&time=" + strconv.FormatInt(t, 10) mac := hmac.New(sha512.New, []byte(secret)) diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_websocket_delivery_futures.go similarity index 100% rename from exchanges/gateio/gateio_ws_delivery_futures.go rename to exchanges/gateio/gateio_websocket_delivery_futures.go diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_websocket_futures.go similarity index 100% rename from exchanges/gateio/gateio_ws_futures.go rename to exchanges/gateio/gateio_websocket_futures.go diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_websocket_option.go similarity index 100% rename from exchanges/gateio/gateio_ws_option.go rename to exchanges/gateio/gateio_websocket_option.go diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index 33c19a5b7c9..0e492cd660f 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -2,9 +2,6 @@ package gateio import ( "context" - "crypto/hmac" - "crypto/sha512" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -17,7 +14,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" ) var ( @@ -27,60 +23,6 @@ var ( errChannelEmpty = errors.New("channel cannot be empty") ) -// websocketLogin authenticates the websocket connection -func (g *Gateio) websocketLogin(ctx context.Context, conn stream.Connection, channel string) (*WebsocketLoginResponse, error) { - if conn == nil { - return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, conn) - } - - if channel == "" { - return nil, errChannelEmpty - } - - creds, err := g.GetCredentials(ctx) - if err != nil { - return nil, err - } - - tn := time.Now().Unix() - msg := "api\n" + channel + "\n" + "\n" + strconv.FormatInt(tn, 10) - mac := hmac.New(sha512.New, []byte(creds.Secret)) - if _, err = mac.Write([]byte(msg)); err != nil { - return nil, err - } - signature := hex.EncodeToString(mac.Sum(nil)) - - payload := WebsocketPayload{ - RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), - APIKey: creds.Key, - Signature: signature, - Timestamp: strconv.FormatInt(tn, 10), - } - - req := WebsocketRequest{Time: tn, Channel: channel, Event: "api", Payload: payload} - - resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, req.Payload.RequestID, req) - if err != nil { - return nil, err - } - - var inbound WebsocketAPIResponse - if err := json.Unmarshal(resp, &inbound); err != nil { - return nil, err - } - - if inbound.Header.Status != "200" { - var wsErr WebsocketErrors - if err := json.Unmarshal(inbound.Data, &wsErr.Errors); err != nil { - return nil, err - } - return nil, fmt.Errorf("%s: %s", wsErr.Errors.Label, wsErr.Errors.Message) - } - - var result WebsocketLoginResponse - return &result, json.Unmarshal(inbound.Data, &result) -} - // WebsocketOrderPlaceSpot places an order via the websocket connection. You can // send multiple orders in a single request. But only for one asset route. func (g *Gateio) WebsocketOrderPlaceSpot(ctx context.Context, batch []WebsocketOrder) ([]WebsocketOrderResponse, error) { From 00a267d16fb848a95ffcc6f9732f250f61718b04 Mon Sep 17 00:00:00 2001 From: shazbert Date: Thu, 21 Nov 2024 10:12:15 +1100 Subject: [PATCH 113/138] gk: nits; :rocket: --- .../gateio/gateio_websocket_request_spot.go | 54 +++++++++---------- .../gateio_websocket_request_spot_test.go | 4 +- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index 0e492cd660f..580562ca3bf 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -17,45 +17,43 @@ import ( ) var ( - errBatchSliceEmpty = errors.New("batch cannot be empty") + errOrdersEmpty = errors.New("orders cannot be empty") errNoOrdersToCancel = errors.New("no orders to cancel") - errEdgeCaseIssue = errors.New("edge case issue") errChannelEmpty = errors.New("channel cannot be empty") ) // WebsocketOrderPlaceSpot places an order via the websocket connection. You can // send multiple orders in a single request. But only for one asset route. -func (g *Gateio) WebsocketOrderPlaceSpot(ctx context.Context, batch []WebsocketOrder) ([]WebsocketOrderResponse, error) { - if len(batch) == 0 { - return nil, errBatchSliceEmpty +func (g *Gateio) WebsocketOrderPlaceSpot(ctx context.Context, orders []WebsocketOrder) ([]WebsocketOrderResponse, error) { + if len(orders) == 0 { + return nil, errOrdersEmpty } - for i := range batch { - if batch[i].Text == "" { - // For some reason the API requires a text field, or it will be - // rejected in the second response. This is a workaround. - batch[i].Text = "t-" + strconv.FormatInt(g.Counter.IncrementAndGet(), 10) + for i := range orders { + if orders[i].Text == "" { + // API requires Text field, or it will be rejected + orders[i].Text = "t-" + strconv.FormatInt(g.Counter.IncrementAndGet(), 10) } - if batch[i].CurrencyPair == "" { + if orders[i].CurrencyPair == "" { return nil, currency.ErrCurrencyPairEmpty } - if batch[i].Side == "" { + if orders[i].Side == "" { return nil, order.ErrSideIsInvalid } - if batch[i].Amount == "" { + if orders[i].Amount == "" { return nil, errInvalidAmount } - if batch[i].Type == "limit" && batch[i].Price == "" { + if orders[i].Type == "limit" && orders[i].Price == "" { return nil, errInvalidPrice } } - if len(batch) == 1 { + if len(orders) == 1 { var singleResponse WebsocketOrderResponse - return []WebsocketOrderResponse{singleResponse}, g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, batch[0], &singleResponse, 2) + return []WebsocketOrderResponse{singleResponse}, g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, orders[0], &singleResponse, 2) } var resp []WebsocketOrderResponse - return resp, g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, batch, &resp, 2) + return resp, g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, orders, &resp, 2) } // WebsocketOrderCancelSpot cancels an order via the websocket connection @@ -95,7 +93,8 @@ func (g *Gateio) WebsocketOrderCancelAllByIDsSpot(ctx context.Context, o []Webso // WebsocketOrderCancelAllByPairSpot cancels all orders for a specific pair func (g *Gateio) WebsocketOrderCancelAllByPairSpot(ctx context.Context, pair currency.Pair, side order.Side, account string) ([]WebsocketOrderResponse, error) { if !pair.IsEmpty() && side == order.UnknownSide { - return nil, fmt.Errorf("%w: side cannot be unknown when pair is set as this will purge *ALL* open orders", errEdgeCaseIssue) + // This case will cancel all orders for every pair, this can be introduced later + return nil, fmt.Errorf("'%v' %w while pair is set", side, order.ErrSideIsInvalid) } sideStr := "" @@ -150,6 +149,11 @@ func (g *Gateio) WebsocketGetOrderStatusSpot(ctx context.Context, orderID string return &resp, g.SendWebsocketRequest(ctx, "spot.order_status", asset.Spot, params, &resp, 1) } +// funnelResult is used to unmarshal the result of a websocket request back to the required caller type +type funnelResult struct { + Result any `json:"result"` +} + // SendWebsocketRequest sends a websocket request to the exchange func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connSignature, params, result any, expectedResponses int) error { paramPayload, err := json.Marshal(params) @@ -176,7 +180,7 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS }, } - responses, err := conn.SendMessageReturnResponses(ctx, request.Unset, req.Payload.RequestID, req, expectedResponses, InspectPayloadForAck) + responses, err := conn.SendMessageReturnResponses(ctx, request.Unset, req.Payload.RequestID, req, expectedResponses, inspectPayloadForAck) if err != nil { return err } @@ -203,18 +207,12 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS return fmt.Errorf("%s: %s", wsErr.Errors.Label, wsErr.Errors.Message) } - to := struct { - Result any `json:"result"` - }{ - Result: result, - } - - return json.Unmarshal(inbound.Data, &to) + return json.Unmarshal(inbound.Data, &funnelResult{Result: result}) } -// InspectPayloadForAck checks the payload for an ack, it returns true if the +// inspectPayloadForAck checks the payload for an ack, it returns true if the // payload does not contain an ack. This will force the cancellation of further // waiting for responses. -func InspectPayloadForAck(data []byte) bool { +func inspectPayloadForAck(data []byte) bool { return !strings.Contains(string(data), "ack") } diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go index 2243cb16055..004806f75e5 100644 --- a/exchanges/gateio/gateio_websocket_request_spot_test.go +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -40,7 +40,7 @@ func TestWebsocketLogin(t *testing.T) { func TestWebsocketOrderPlaceSpot(t *testing.T) { t.Parallel() _, err := g.WebsocketOrderPlaceSpot(context.Background(), nil) - require.ErrorIs(t, err, errBatchSliceEmpty) + require.ErrorIs(t, err, errOrdersEmpty) _, err = g.WebsocketOrderPlaceSpot(context.Background(), make([]WebsocketOrder, 1)) require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) out := WebsocketOrder{CurrencyPair: "BTC_USDT"} @@ -120,7 +120,7 @@ func TestWebsocketOrderCancelAllByPairSpot(t *testing.T) { require.NoError(t, err) _, err = g.WebsocketOrderCancelAllByPairSpot(context.Background(), pair, 0, "") - require.ErrorIs(t, err, errEdgeCaseIssue) + require.ErrorIs(t, err, order.ErrSideIsInvalid) sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) From 72f97dd69a6b742fc345217dde7626a09490288b Mon Sep 17 00:00:00 2001 From: shazbert Date: Thu, 21 Nov 2024 10:40:56 +1100 Subject: [PATCH 114/138] gk: nits; convert variadic function and message inspection to interface and include a specific function for that handling so as to not need nil on every call --- .../gateio/gateio_websocket_request_spot.go | 6 +++--- exchanges/stream/stream_types.go | 11 +++++++---- exchanges/stream/websocket_connection.go | 19 ++++++++++--------- exchanges/stream/websocket_test.go | 9 ++++++--- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index 580562ca3bf..e0011a6fcc5 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -180,7 +180,7 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS }, } - responses, err := conn.SendMessageReturnResponses(ctx, request.Unset, req.Payload.RequestID, req, expectedResponses, inspectPayloadForAck) + responses, err := conn.SendMessageReturnResponsesWithInspector(ctx, request.Unset, req.Payload.RequestID, req, expectedResponses, g) if err != nil { return err } @@ -210,9 +210,9 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS return json.Unmarshal(inbound.Data, &funnelResult{Result: result}) } -// inspectPayloadForAck checks the payload for an ack, it returns true if the +// Inspect checks the payload for an ack, it returns true if the // payload does not contain an ack. This will force the cancellation of further // waiting for responses. -func inspectPayloadForAck(data []byte) bool { +func (g *Gateio) Inspect(data []byte) bool { return !strings.Contains(string(data), "ack") } diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index e6d0622eb05..3a7006fddc4 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -26,7 +26,9 @@ type Connection interface { // SendMessageReturnResponse will send a WS message to the connection and wait for response SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature any, request any) ([]byte, error) // SendMessageReturnResponses will send a WS message to the connection and wait for N responses - SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature any, request any, expected int, messageInspector ...Inspector) ([][]byte, error) + SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature any, request any, expected int) ([][]byte, error) + // SendMessageReturnResponsesWithInspector will send a WS message to the connection and wait for N responses with message inspection + SendMessageReturnResponsesWithInspector(ctx context.Context, epl request.EndpointLimit, signature any, request any, expected int, messageInspector Inspector) ([][]byte, error) // SendRawMessage sends a message over the connection without JSON encoding it SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error // SendJSONMessage sends a JSON encoded message over the connection @@ -37,10 +39,11 @@ type Connection interface { Shutdown() error } -// Inspector is used to verify messages via SendMessageReturnResponse -// Only one can used +// Inspector is used to verify messages via SendMessageReturnResponsesWithInspection // It inspects the []bytes websocket message and returns true if it is the appropriate message to action -type Inspector func([]byte) bool +type Inspector interface { + Inspect([]byte) bool +} // Response defines generalised data from the stream connection type Response struct { diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 9ccf43c0f91..10cd8dc8828 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -28,7 +28,6 @@ var ( errConnectionFault = errors.New("connection fault") errWebsocketIsDisconnected = errors.New("websocket connection is disconnected") errRateLimitNotFound = errors.New("rate limit definition not found") - errOnlyOneMessageInspector = errors.New("only one message inspector can be used") ) // Dial sets proxy urls and then connects to the websocket @@ -304,7 +303,13 @@ func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, epl // SendMessageReturnResponses will send a WS message to the connection and wait for N responses // An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked -func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int, messageInspector ...Inspector) ([][]byte, error) { +func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int) ([][]byte, error) { + return w.SendMessageReturnResponsesWithInspector(ctx, epl, signature, payload, expected, nil) +} + +// SendMessageReturnResponsesWithInspector will send a WS message to the connection and wait for N responses +// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked +func (w *WebsocketConnection) SendMessageReturnResponsesWithInspector(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int, messageInspector Inspector) ([][]byte, error) { outbound, err := json.Marshal(payload) if err != nil { return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err) @@ -321,7 +326,7 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep return nil, err } - resps, err := w.waitForResponses(ctx, signature, ch, expected, messageInspector...) + resps, err := w.waitForResponses(ctx, signature, ch, expected, messageInspector) if err != nil { return nil, err } @@ -334,11 +339,7 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep } // waitForResponses waits for N responses from a channel -func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature any, ch <-chan []byte, expected int, messageInspector ...Inspector) ([][]byte, error) { - if len(messageInspector) > 1 { - return nil, errOnlyOneMessageInspector - } - +func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature any, ch <-chan []byte, expected int, messageInspector Inspector) ([][]byte, error) { timeout := time.NewTimer(w.ResponseMaxLimit * time.Duration(expected)) defer timeout.Stop() @@ -348,7 +349,7 @@ func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature an case resp := <-ch: resps = append(resps, resp) // Checks recently received message to determine if this is in fact the final message in a sequence of messages. - if len(messageInspector) == 1 && messageInspector[0](resp) { + if messageInspector != nil && messageInspector.Inspect(resp) { w.Match.RemoveSignature(signature) return resps, nil } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index cf935c3b219..3053a0fa840 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -765,11 +765,14 @@ func TestSendMessageReturnResponse(t *testing.T) { _, err = wc.SendMessageReturnResponse(context.Background(), request.Unset, "123", req) assert.ErrorIs(t, err, ErrSignatureTimeout, "SendMessageReturnResponse should error when request ID not found") - inspector := func(b []byte) bool { return false } - _, err = wc.SendMessageReturnResponses(context.Background(), request.Unset, "123", req, 1, inspector, inspector) - assert.ErrorIs(t, err, errOnlyOneMessageInspector) + _, err = wc.SendMessageReturnResponsesWithInspector(context.Background(), request.Unset, "123", req, 1, inspection{}) + assert.ErrorIs(t, err, ErrSignatureTimeout, "SendMessageReturnResponse should error when request ID not found") } +type inspection struct{} + +func (i inspection) Inspect([]byte) bool { return false } + type reporter struct { name string msg []byte From 274ae8fce02f4fbb72074fdc7a5e97ddfd9858ae Mon Sep 17 00:00:00 2001 From: shazbert Date: Thu, 21 Nov 2024 10:50:06 +1100 Subject: [PATCH 115/138] gk: nits; continued --- exchanges/stream/stream_types.go | 5 +--- exchanges/stream/websocket.go | 36 ++++++++++++++--------------- exchanges/stream/websocket_test.go | 6 ++--- exchanges/stream/websocket_types.go | 4 ++-- 4 files changed, 24 insertions(+), 27 deletions(-) diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 3a7006fddc4..bf72a42afda 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -84,10 +84,7 @@ type ConnectionSetup struct { // This is useful for when an exchange connection requires a unique or // structured message ID for each message sent. BespokeGenerateMessageID func(highPrecision bool) int64 - // Authenticate is a function that will be called to authenticate the - // connection to the exchange's websocket server. This function should - // handle the authentication process and return an error if the - // authentication fails. + // Authenticate will be called to authenticate the connection Authenticate func(ctx context.Context, conn Connection) error // WrapperDefinedConnectionSignature is any type that will match to a specific connection. This could be an asset // type `asset.Spot`, a string type denoting the individual URL, an authenticated or unauthenticated string or a diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 1976c8b8f3f..18bed4ee457 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -92,12 +92,12 @@ func NewWebsocket() *Websocket { // after subscriptions are made but before the connectionMonitor has // started. This allows the error to be read and handled in the // connectionMonitor and start a connection cycle again. - ReadMessageErrors: make(chan error, 1), - Match: NewMatch(), - subscriptions: subscription.NewStore(), - features: &protocol.Features{}, - Orderbook: buffer.Orderbook{}, - connectionToWrapper: make(map[Connection]*ConnectionWrapper), + ReadMessageErrors: make(chan error, 1), + Match: NewMatch(), + subscriptions: subscription.NewStore(), + features: &protocol.Features{}, + Orderbook: buffer.Orderbook{}, + connections: make(map[Connection]*ConnectionWrapper), } } @@ -435,7 +435,7 @@ func (w *Websocket) connect() error { break } - w.connectionToWrapper[conn] = w.connectionManager[i] + w.connections[conn] = w.connectionManager[i] w.connectionManager[i].Connection = conn w.Wg.Add(1) @@ -476,7 +476,7 @@ func (w *Websocket) connect() error { } w.connectionManager[x].Subscriptions.Clear() } - clear(w.connectionToWrapper) + clear(w.connections) w.setState(disconnectedState) // Flip from connecting to disconnected. // Drain residual error in the single buffered channel, this mitigates @@ -564,7 +564,7 @@ func (w *Websocket) shutdown() error { } } // Clean map of old connections - clear(w.connectionToWrapper) + clear(w.connections) if w.Conn != nil { if err := w.Conn.Shutdown(); err != nil { @@ -655,7 +655,7 @@ func (w *Websocket) FlushChannels() error { } w.Wg.Add(1) go w.Reader(context.TODO(), conn, w.connectionManager[x].Setup.Handler) - w.connectionToWrapper[conn] = w.connectionManager[x] + w.connections[conn] = w.connectionManager[x] w.connectionManager[x].Connection = conn } @@ -674,7 +674,7 @@ func (w *Websocket) FlushChannels() error { // If there are no subscriptions to subscribe to, close the connection as it is no longer needed. if w.connectionManager[x].Subscriptions.Len() == 0 { - delete(w.connectionToWrapper, w.connectionManager[x].Connection) // Remove from lookup map + delete(w.connections, w.connectionManager[x].Connection) // Remove from lookup map if err := w.connectionManager[x].Connection.Shutdown(); err != nil { log.Warnf(log.WebsocketMgr, "%v websocket: failed to shutdown connection: %v", w.exchangeName, err) } @@ -835,7 +835,7 @@ func (w *Websocket) GetName() string { // and the new subscription list when pairs are disabled or enabled. func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) { var subscriptionStore **subscription.Store - if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { + if wrapper, ok := w.connections[conn]; ok && conn != nil { subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -851,7 +851,7 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L if len(channels) == 0 { return nil // No channels to unsubscribe from is not an error } - if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { + if wrapper, ok := w.connections[conn]; ok && conn != nil { return w.unsubscribe(wrapper.Subscriptions, channels, func(channels subscription.List) error { return wrapper.Setup.Unsubscriber(context.TODO(), conn, channels) }) @@ -897,7 +897,7 @@ func (w *Websocket) SubscribeToChannels(conn Connection, subs subscription.List) return err } - if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { + if wrapper, ok := w.connections[conn]; ok && conn != nil { return wrapper.Setup.Subscriber(context.TODO(), conn, subs) } @@ -918,7 +918,7 @@ func (w *Websocket) AddSubscriptions(conn Connection, subs ...*subscription.Subs return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer) } var subscriptionStore **subscription.Store - if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { + if wrapper, ok := w.connections[conn]; ok && conn != nil { subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -948,7 +948,7 @@ func (w *Websocket) AddSuccessfulSubscriptions(conn Connection, subs ...*subscri } var subscriptionStore **subscription.Store - if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { + if wrapper, ok := w.connections[conn]; ok && conn != nil { subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -977,7 +977,7 @@ func (w *Websocket) RemoveSubscriptions(conn Connection, subs ...*subscription.S } var subscriptionStore *subscription.Store - if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { + if wrapper, ok := w.connections[conn]; ok && conn != nil { subscriptionStore = wrapper.Subscriptions } else { subscriptionStore = w.subscriptions @@ -1064,7 +1064,7 @@ func checkWebsocketURL(s string) error { // The subscription state is not considered when counting existing subscriptions func (w *Websocket) checkSubscriptions(conn Connection, subs subscription.List) error { var subscriptionStore *subscription.Store - if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { + if wrapper, ok := w.connections[conn]; ok && conn != nil { subscriptionStore = wrapper.Subscriptions } else { subscriptionStore = w.subscriptions diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 3053a0fa840..9d28b7721c4 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -466,7 +466,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { require.NoError(t, multi.SetupNewConnection(amazingCandidate)) amazingConn := multi.getConnectionFromSetup(amazingCandidate) - multi.connectionToWrapper = map[Connection]*ConnectionWrapper{ + multi.connections = map[Connection]*ConnectionWrapper{ amazingConn: multi.connectionManager[0], } @@ -985,7 +985,7 @@ func TestGetChannelDifference(t *testing.T) { require.Equal(t, 1, len(subs)) require.Empty(t, unsubs, "Should get no unsubs") - w.connectionToWrapper = map[Connection]*ConnectionWrapper{ + w.connections = map[Connection]*ConnectionWrapper{ sweetConn: {Setup: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}, } @@ -998,7 +998,7 @@ func TestGetChannelDifference(t *testing.T) { require.Equal(t, 1, len(subs)) require.Empty(t, unsubs, "Should get no unsubs") - err := w.connectionToWrapper[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel}) + err := w.connections[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel}) require.NoError(t, err) subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}}) diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 5152ff2c342..26b20f1ee74 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -55,8 +55,8 @@ type Websocket struct { // for exchanges that differentiate between trading pairs by using different connection endpoints or protocols for various asset classes. // If an exchange does not require such differentiation, all connections may be managed under a single ConnectionWrapper. connectionManager []*ConnectionWrapper - // connectionToWrapper holds a look up table for all connections to their corresponding ConnectionWrapper and subscription holder - connectionToWrapper map[Connection]*ConnectionWrapper + // connections holds a look up table for all connections to their corresponding ConnectionWrapper and subscription holder + connections map[Connection]*ConnectionWrapper subscriptions *subscription.Store From 25f00ceb15947eb69e2fcca3f20c3754dfa46cbc Mon Sep 17 00:00:00 2001 From: shazbert Date: Sat, 23 Nov 2024 06:38:52 +1100 Subject: [PATCH 116/138] gk: engine nits; rm loaded exchange --- engine/engine.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/engine.go b/engine/engine.go index 878b8a9fedc..0c8e2de475e 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -795,7 +795,7 @@ func (bot *Engine) LoadExchange(name string) error { if err != nil { gctlog.Errorf(gctlog.ExchangeSys, "error disabling rate limiter for %s: %v", exch.GetName(), err) } else { - gctlog.Warnf(gctlog.ExchangeSys, "Loaded exchange %s rate limiting has been turned off\n", exch.GetName()) + gctlog.Warnf(gctlog.ExchangeSys, "%s rate limiting has been turned off", exch.GetName()) } } From 058b18432d24c5f20eaaddfbec0ae7275647c530 Mon Sep 17 00:00:00 2001 From: shazbert Date: Sat, 23 Nov 2024 06:49:13 +1100 Subject: [PATCH 117/138] gk: nits; drop WebsocketLoginResponse --- exchanges/gateio/gateio_websocket.go | 24 +++++++++---------- .../gateio_websocket_request_spot_test.go | 7 +++--- .../gateio/gateio_websocket_request_types.go | 9 ------- 3 files changed, 14 insertions(+), 26 deletions(-) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 4044fdecd50..fac5bd2ce98 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -99,30 +99,29 @@ func (g *Gateio) WsConnectSpot(ctx context.Context, conn stream.Connection) erro // authenticateSpot sends an authentication message to the websocket connection func (g *Gateio) authenticateSpot(ctx context.Context, conn stream.Connection) error { - _, err := g.websocketLogin(ctx, conn, "spot.login") - return err + return g.websocketLogin(ctx, conn, "spot.login") } // websocketLogin authenticates the websocket connection -func (g *Gateio) websocketLogin(ctx context.Context, conn stream.Connection, channel string) (*WebsocketLoginResponse, error) { +func (g *Gateio) websocketLogin(ctx context.Context, conn stream.Connection, channel string) error { if conn == nil { - return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, conn) + return fmt.Errorf("%w: %T", common.ErrNilPointer, conn) } if channel == "" { - return nil, errChannelEmpty + return errChannelEmpty } creds, err := g.GetCredentials(ctx) if err != nil { - return nil, err + return err } tn := time.Now().Unix() msg := "api\n" + channel + "\n" + "\n" + strconv.FormatInt(tn, 10) mac := hmac.New(sha512.New, []byte(creds.Secret)) if _, err = mac.Write([]byte(msg)); err != nil { - return nil, err + return err } signature := hex.EncodeToString(mac.Sum(nil)) @@ -137,24 +136,23 @@ func (g *Gateio) websocketLogin(ctx context.Context, conn stream.Connection, cha resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, req.Payload.RequestID, req) if err != nil { - return nil, err + return err } var inbound WebsocketAPIResponse if err := json.Unmarshal(resp, &inbound); err != nil { - return nil, err + return err } if inbound.Header.Status != "200" { var wsErr WebsocketErrors if err := json.Unmarshal(inbound.Data, &wsErr.Errors); err != nil { - return nil, err + return err } - return nil, fmt.Errorf("%s: %s", wsErr.Errors.Label, wsErr.Errors.Message) + return fmt.Errorf("%s: %s", wsErr.Errors.Label, wsErr.Errors.Message) } - var result WebsocketLoginResponse - return &result, json.Unmarshal(inbound.Data, &result) + return nil } func (g *Gateio) generateWsSignature(secret, event, channel string, t int64) (string, error) { diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go index 004806f75e5..1ee75f34859 100644 --- a/exchanges/gateio/gateio_websocket_request_spot_test.go +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -18,10 +18,10 @@ import ( func TestWebsocketLogin(t *testing.T) { t.Parallel() - _, err := g.websocketLogin(context.Background(), nil, "") + err := g.websocketLogin(context.Background(), nil, "") require.ErrorIs(t, err, common.ErrNilPointer) - _, err = g.websocketLogin(context.Background(), &stream.WebsocketConnection{}, "") + err = g.websocketLogin(context.Background(), &stream.WebsocketConnection{}, "") require.ErrorIs(t, err, errChannelEmpty) sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) @@ -32,9 +32,8 @@ func TestWebsocketLogin(t *testing.T) { demonstrationConn, err := g.Websocket.GetConnection(asset.Spot) require.NoError(t, err) - got, err := g.websocketLogin(context.Background(), demonstrationConn, "spot.login") + err = g.websocketLogin(context.Background(), demonstrationConn, "spot.login") require.NoError(t, err) - require.NotEmpty(t, got) } func TestWebsocketOrderPlaceSpot(t *testing.T) { diff --git a/exchanges/gateio/gateio_websocket_request_types.go b/exchanges/gateio/gateio_websocket_request_types.go index b16aaaa0d48..165eea41cba 100644 --- a/exchanges/gateio/gateio_websocket_request_types.go +++ b/exchanges/gateio/gateio_websocket_request_types.go @@ -52,15 +52,6 @@ type WebsocketErrors struct { } `json:"errs"` } -// WebsocketLoginResponse defines a websocket login response when authenticating -// the connection. -type WebsocketLoginResponse struct { - Result struct { - APIKey string `json:"api_key"` - UID string `json:"uid"` - } `json:"result"` -} - // WebsocketOrder defines a websocket order type WebsocketOrder struct { Text string `json:"text"` From a87c4387ecec97906103b2856e8298d7e10f72c2 Mon Sep 17 00:00:00 2001 From: shazbert Date: Sat, 23 Nov 2024 07:16:06 +1100 Subject: [PATCH 118/138] stream: Add match method EnsureMatchWithData --- exchanges/bitfinex/bitfinex_websocket.go | 20 ++++++++++++-------- exchanges/bitmex/bitmex_websocket.go | 5 +++-- exchanges/gateio/gateio_websocket.go | 10 ++-------- exchanges/kraken/kraken_websocket.go | 2 +- exchanges/kucoin/kucoin_websocket.go | 8 ++------ exchanges/stream/stream_match.go | 12 ++++++++++++ exchanges/stream/stream_match_test.go | 15 +++++++++++++++ exchanges/stream/websocket.go | 1 - 8 files changed, 47 insertions(+), 26 deletions(-) diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index 74c04afc2d3..1d5425f33b3 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -456,17 +456,20 @@ func (b *Bitfinex) handleWSEvent(respRaw []byte) error { if err != nil { return fmt.Errorf("%w 'chanId': %w from message: %s", errParsingWSField, err, respRaw) } - if !b.Websocket.Match.IncomingWithData("unsubscribe:"+chanID, respRaw) { - return fmt.Errorf("%w: unsubscribe:%v", stream.ErrNoMessageListener, chanID) + err = b.Websocket.Match.EnsureMatchWithData("unsubscribe:"+chanID, respRaw) + if err != nil { + return fmt.Errorf("%w: unsubscribe:%v", err, chanID) } case wsEventError: if subID, err := jsonparser.GetUnsafeString(respRaw, "subId"); err == nil { - if !b.Websocket.Match.IncomingWithData("subscribe:"+subID, respRaw) { - return fmt.Errorf("%w: subscribe:%v", stream.ErrNoMessageListener, subID) + err = b.Websocket.Match.EnsureMatchWithData("subscribe:"+subID, respRaw) + if err != nil { + return fmt.Errorf("%w: subscribe:%v", err, subID) } } else if chanID, err := jsonparser.GetUnsafeString(respRaw, "chanId"); err == nil { - if !b.Websocket.Match.IncomingWithData("unsubscribe:"+chanID, respRaw) { - return fmt.Errorf("%w: unsubscribe:%v", stream.ErrNoMessageListener, chanID) + err = b.Websocket.Match.EnsureMatchWithData("unsubscribe:"+chanID, respRaw) + if err != nil { + return fmt.Errorf("%w: unsubscribe:%v", err, chanID) } } else { return fmt.Errorf("unknown channel error; Message: %s", respRaw) @@ -538,8 +541,9 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { if b.Verbose { log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s Pair: %s ChannelID: %d\n", b.Name, c.Channel, c.Pairs, chanID) } - if !b.Websocket.Match.IncomingWithData("subscribe:"+subID, respRaw) { - return fmt.Errorf("%w: subscribe:%v", stream.ErrNoMessageListener, subID) + err = b.Websocket.Match.EnsureMatchWithData("subscribe:"+subID, respRaw) + if err != nil { + return fmt.Errorf("%w: subscribe:%v", err, subID) } return nil } diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index a18fa161216..6c6cd3c924c 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -170,8 +170,9 @@ func (b *Bitmex) wsHandleData(respRaw []byte) error { if e2 != nil { return fmt.Errorf("%w parsing stream", e2) } - if !b.Websocket.Match.IncomingWithData(op+":"+streamID, msg) { - return fmt.Errorf("%w: %s:%s", stream.ErrNoMessageListener, op, streamID) + err = b.Websocket.Match.EnsureMatchWithData(op+":"+streamID, msg) + if err != nil { + return fmt.Errorf("%w: %s:%s", err, op, streamID) } return nil } diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index fac5bd2ce98..105784e42aa 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -172,17 +172,11 @@ func (g *Gateio) WsHandleSpotData(_ context.Context, respRaw []byte) error { } if push.RequestID != "" { - if !g.Websocket.Match.IncomingWithData(push.RequestID, respRaw) { - return fmt.Errorf("%w for requestID %v", stream.ErrNoMessageListener, push.RequestID) - } - return nil + return g.Websocket.Match.EnsureMatchWithData(push.RequestID, respRaw) } if push.Event == subscribeEvent || push.Event == unsubscribeEvent { - if !g.Websocket.Match.IncomingWithData(push.ID, respRaw) { - return fmt.Errorf("%w couldn't match subscription message with ID: %d", stream.ErrNoMessageListener, push.ID) - } - return nil + return g.Websocket.Match.EnsureMatchWithData(push.ID, respRaw) } switch push.Channel { // TODO: Convert function params below to only use push.Result diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index 22088590463..5b276055181 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -231,7 +231,7 @@ func (k *Kraken) wsHandleData(respRaw []byte) error { return nil case krakenWsCancelOrderStatus, krakenWsCancelAllOrderStatus, krakenWsAddOrderStatus, krakenWsSubscriptionStatus: // All of these should have found a listener already - return fmt.Errorf("%w: %s %v", stream.ErrNoMessageListener, event, reqID) + return fmt.Errorf("%w: %s %v", stream.ErrSignatureNotMatched, event, reqID) case krakenWsSystemStatus: return k.wsProcessSystemStatus(respRaw) default: diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index 8b4aa535e26..956239e25a5 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -215,18 +215,14 @@ func (ku *Kucoin) wsReadData() { // wsHandleData processes a websocket incoming data. func (ku *Kucoin) wsHandleData(respData []byte) error { resp := WsPushData{} - err := json.Unmarshal(respData, &resp) - if err != nil { + if err := json.Unmarshal(respData, &resp); err != nil { return err } if resp.Type == "pong" || resp.Type == "welcome" { return nil } if resp.ID != "" { - if !ku.Websocket.Match.IncomingWithData("msgID:"+resp.ID, respData) { - return fmt.Errorf("%w: %s", stream.ErrNoMessageListener, resp.ID) - } - return nil + return ku.Websocket.Match.EnsureMatchWithData("msgID:"+resp.ID, respData) } topicInfo := strings.Split(resp.Topic, ":") switch topicInfo[0] { diff --git a/exchanges/stream/stream_match.go b/exchanges/stream/stream_match.go index a7b8a10bfab..d8fa488ae6a 100644 --- a/exchanges/stream/stream_match.go +++ b/exchanges/stream/stream_match.go @@ -2,9 +2,12 @@ package stream import ( "errors" + "fmt" "sync" ) +var ErrSignatureNotMatched = errors.New("websocket response to request signature not matched") + var ( errSignatureCollision = errors.New("signature collision") errInvalidBufferSize = errors.New("buffer size must be positive") @@ -47,6 +50,15 @@ func (m *Match) IncomingWithData(signature any, data []byte) bool { return true } +// EnsureMatchWithData validates that incoming data matches a request's signature. +// If a match is found, the data is processed; otherwise, it returns an error. +func (m *Match) EnsureMatchWithData(signature any, data []byte) error { + if m.IncomingWithData(signature, data) { + return nil + } + return fmt.Errorf("'%v' %w with data %v", signature, ErrSignatureNotMatched, string(data)) +} + // Set the signature response channel for incoming data func (m *Match) Set(signature any, bufSize int) (<-chan []byte, error) { if bufSize <= 0 { diff --git a/exchanges/stream/stream_match_test.go b/exchanges/stream/stream_match_test.go index b7a21b23c05..5a873ce80c6 100644 --- a/exchanges/stream/stream_match_test.go +++ b/exchanges/stream/stream_match_test.go @@ -51,3 +51,18 @@ func TestRemoveSignature(t *testing.T) { t.Fatal("Should be able to read from a closed channel") } } + +func TestEnsureMatchWithData(t *testing.T) { + t.Parallel() + match := NewMatch() + err := match.EnsureMatchWithData("hello", []byte("world")) + require.ErrorIs(t, err, ErrSignatureNotMatched, "Should error on unmatched signature") + assert.Contains(t, err.Error(), "world", "Should contain the data in the error message") + assert.Contains(t, err.Error(), "hello", "Should contain the signature in the error message") + + ch, err := match.Set("hello", 1) + require.NoError(t, err, "Set must not error") + err = match.EnsureMatchWithData("hello", []byte("world")) + require.NoError(t, err, "Should not error on matched signature") + assert.Equal(t, "world", string(<-ch)) +} diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 18bed4ee457..22746c426fe 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -28,7 +28,6 @@ var ( ErrUnsubscribeFailure = errors.New("unsubscribe failure") ErrAlreadyDisabled = errors.New("websocket already disabled") ErrNotConnected = errors.New("websocket is not connected") - ErrNoMessageListener = errors.New("websocket listener not found for message") ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature") ErrRequestRouteNotFound = errors.New("request route not found") ErrSignatureNotSet = errors.New("signature not set") From 6a2fbdec024a0e192a9255a6b23d9fd9eaf82338 Mon Sep 17 00:00:00 2001 From: shazbert Date: Sat, 23 Nov 2024 07:39:33 +1100 Subject: [PATCH 119/138] gk: nits; rn Inspect to IsFinal --- exchanges/gateio/gateio_websocket_request_spot.go | 11 ++++++----- exchanges/stream/stream_types.go | 4 ++-- exchanges/stream/websocket_connection.go | 2 +- exchanges/stream/websocket_test.go | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index e0011a6fcc5..52d11d0197b 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -180,7 +180,7 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS }, } - responses, err := conn.SendMessageReturnResponsesWithInspector(ctx, request.Unset, req.Payload.RequestID, req, expectedResponses, g) + responses, err := conn.SendMessageReturnResponsesWithInspector(ctx, request.Unset, req.Payload.RequestID, req, expectedResponses, wsRespAckInspector{}) if err != nil { return err } @@ -210,9 +210,10 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS return json.Unmarshal(inbound.Data, &funnelResult{Result: result}) } -// Inspect checks the payload for an ack, it returns true if the -// payload does not contain an ack. This will force the cancellation of further -// waiting for responses. -func (g *Gateio) Inspect(data []byte) bool { +type wsRespAckInspector struct{} + +// IsFinal checks the payload for an ack, it returns true if the payload does not contain an ack. +// This will force the cancellation of further waiting for responses. +func (wsRespAckInspector) IsFinal(data []byte) bool { return !strings.Contains(string(data), "ack") } diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index bf72a42afda..e1c48910931 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -40,9 +40,9 @@ type Connection interface { } // Inspector is used to verify messages via SendMessageReturnResponsesWithInspection -// It inspects the []bytes websocket message and returns true if it is the appropriate message to action +// It inspects the []bytes websocket message and returns true if the message is the final message in a sequence of expected messages type Inspector interface { - Inspect([]byte) bool + IsFinal([]byte) bool } // Response defines generalised data from the stream connection diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 10cd8dc8828..92aace90ef1 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -349,7 +349,7 @@ func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature an case resp := <-ch: resps = append(resps, resp) // Checks recently received message to determine if this is in fact the final message in a sequence of messages. - if messageInspector != nil && messageInspector.Inspect(resp) { + if messageInspector != nil && messageInspector.IsFinal(resp) { w.Match.RemoveSignature(signature) return resps, nil } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 9d28b7721c4..a1e4bdcf15d 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -771,7 +771,7 @@ func TestSendMessageReturnResponse(t *testing.T) { type inspection struct{} -func (i inspection) Inspect([]byte) bool { return false } +func (inspection) IsFinal([]byte) bool { return false } type reporter struct { name string From cb965218c36e27c0547c98ccd2a2c9cb7e6ecb16 Mon Sep 17 00:00:00 2001 From: shazbert Date: Sat, 23 Nov 2024 07:51:48 +1100 Subject: [PATCH 120/138] gk: nits; rn to MessageFilter --- exchanges/gateio/gateio_wrapper.go | 82 +++++++++++++++--------------- exchanges/stream/stream_types.go | 7 ++- exchanges/stream/websocket.go | 76 +++++++++++++-------------- exchanges/stream/websocket_test.go | 8 +-- 4 files changed, 86 insertions(+), 87 deletions(-) diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 5291449b139..8a396b39f85 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -209,18 +209,18 @@ func (g *Gateio) Setup(exch *config.Exchange) error { } // Spot connection err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ - URL: gateioWebsocketEndpoint, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: g.WsHandleSpotData, - Subscriber: g.Subscribe, - Unsubscriber: g.Unsubscribe, - GenerateSubscriptions: g.generateSubscriptionsSpot, - Connector: g.WsConnectSpot, - Authenticate: g.authenticateSpot, - WrapperDefinedConnectionSignature: asset.Spot, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + URL: gateioWebsocketEndpoint, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleSpotData, + Subscriber: g.Subscribe, + Unsubscriber: g.Unsubscribe, + GenerateSubscriptions: g.generateSubscriptionsSpot, + Connector: g.WsConnectSpot, + Authenticate: g.authenticateSpot, + MessageFilter: asset.Spot, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -234,12 +234,12 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.Futures) }, - Subscriber: g.FuturesSubscribe, - Unsubscriber: g.FuturesUnsubscribe, - GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, - Connector: g.WsFuturesConnect, - WrapperDefinedConnectionSignature: asset.USDTMarginedFutures, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, + Connector: g.WsFuturesConnect, + MessageFilter: asset.USDTMarginedFutures, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -254,12 +254,12 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.Futures) }, - Subscriber: g.FuturesSubscribe, - Unsubscriber: g.FuturesUnsubscribe, - GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, - Connector: g.WsFuturesConnect, - WrapperDefinedConnectionSignature: asset.CoinMarginedFutures, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, + Connector: g.WsFuturesConnect, + MessageFilter: asset.CoinMarginedFutures, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -275,12 +275,12 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.DeliveryFutures) }, - Subscriber: g.DeliveryFuturesSubscribe, - Unsubscriber: g.DeliveryFuturesUnsubscribe, - GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions, - Connector: g.WsDeliveryFuturesConnect, - WrapperDefinedConnectionSignature: asset.DeliveryFutures, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + Subscriber: g.DeliveryFuturesSubscribe, + Unsubscriber: g.DeliveryFuturesUnsubscribe, + GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions, + Connector: g.WsDeliveryFuturesConnect, + MessageFilter: asset.DeliveryFutures, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -288,17 +288,17 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Futures connection - Options return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ - URL: optionsWebsocketURL, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: g.WsHandleOptionsData, - Subscriber: g.OptionsSubscribe, - Unsubscriber: g.OptionsUnsubscribe, - GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, - Connector: g.WsOptionsConnect, - WrapperDefinedConnectionSignature: asset.Options, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + URL: optionsWebsocketURL, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleOptionsData, + Subscriber: g.OptionsSubscribe, + Unsubscriber: g.OptionsUnsubscribe, + GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, + Connector: g.WsOptionsConnect, + MessageFilter: asset.Options, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) } diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index e1c48910931..832e74c5526 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -86,10 +86,9 @@ type ConnectionSetup struct { BespokeGenerateMessageID func(highPrecision bool) int64 // Authenticate will be called to authenticate the connection Authenticate func(ctx context.Context, conn Connection) error - // WrapperDefinedConnectionSignature is any type that will match to a specific connection. This could be an asset - // type `asset.Spot`, a string type denoting the individual URL, an authenticated or unauthenticated string or a - // mixture of these. - WrapperDefinedConnectionSignature any + // MessageFilter defines the criteria used to match messages to a specific connection. + // The filter enables precise routing and handling of messages for distinct connection contexts. + MessageFilter any } // ConnectionWrapper contains the connection setup details to be used when diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 22746c426fe..1e716d9f222 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -36,40 +36,40 @@ var ( // Private websocket errors var ( - errExchangeConfigIsNil = errors.New("exchange config is nil") - errWebsocketIsNil = errors.New("websocket is nil") - errWebsocketSetupIsNil = errors.New("websocket setup is nil") - errWebsocketAlreadyInitialised = errors.New("websocket already initialised") - errWebsocketAlreadyEnabled = errors.New("websocket already enabled") - errWebsocketFeaturesIsUnset = errors.New("websocket features is unset") - errConfigFeaturesIsNil = errors.New("exchange config features is nil") - errDefaultURLIsEmpty = errors.New("default url is empty") - errRunningURLIsEmpty = errors.New("running url cannot be empty") - errInvalidWebsocketURL = errors.New("invalid websocket url") - errExchangeConfigNameEmpty = errors.New("exchange config name empty") - errInvalidTrafficTimeout = errors.New("invalid traffic timeout") - errTrafficAlertNil = errors.New("traffic alert is nil") - errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") - errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") - errWebsocketConnectorUnset = errors.New("websocket connector function not set") - errWebsocketDataHandlerUnset = errors.New("websocket data handler not set") - errReadMessageErrorsNil = errors.New("read message errors is nil") - errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") - errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") - errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0") - errSameProxyAddress = errors.New("cannot set proxy address to the same address") - errNoConnectFunc = errors.New("websocket connect func not set") - errAlreadyConnected = errors.New("websocket already connected") - errCannotShutdown = errors.New("websocket cannot shutdown") - errAlreadyReconnecting = errors.New("websocket in the process of reconnection") - errConnSetup = errors.New("error in connection setup") - errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") - errConnectionWrapperDuplication = errors.New("connection wrapper duplication") - errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") - errExchangeConfigEmpty = errors.New("exchange config is empty") - errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") - errConnectionSignatureNotSet = errors.New("connection signature not set") - errWrapperDefinedConnectionSignatureNotComparable = errors.New("wrapper defined connection signature is not comparable") + errExchangeConfigIsNil = errors.New("exchange config is nil") + errWebsocketIsNil = errors.New("websocket is nil") + errWebsocketSetupIsNil = errors.New("websocket setup is nil") + errWebsocketAlreadyInitialised = errors.New("websocket already initialised") + errWebsocketAlreadyEnabled = errors.New("websocket already enabled") + errWebsocketFeaturesIsUnset = errors.New("websocket features is unset") + errConfigFeaturesIsNil = errors.New("exchange config features is nil") + errDefaultURLIsEmpty = errors.New("default url is empty") + errRunningURLIsEmpty = errors.New("running url cannot be empty") + errInvalidWebsocketURL = errors.New("invalid websocket url") + errExchangeConfigNameEmpty = errors.New("exchange config name empty") + errInvalidTrafficTimeout = errors.New("invalid traffic timeout") + errTrafficAlertNil = errors.New("traffic alert is nil") + errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") + errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") + errWebsocketConnectorUnset = errors.New("websocket connector function not set") + errWebsocketDataHandlerUnset = errors.New("websocket data handler not set") + errReadMessageErrorsNil = errors.New("read message errors is nil") + errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") + errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") + errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0") + errSameProxyAddress = errors.New("cannot set proxy address to the same address") + errNoConnectFunc = errors.New("websocket connect func not set") + errAlreadyConnected = errors.New("websocket already connected") + errCannotShutdown = errors.New("websocket cannot shutdown") + errAlreadyReconnecting = errors.New("websocket in the process of reconnection") + errConnSetup = errors.New("error in connection setup") + errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") + errConnectionWrapperDuplication = errors.New("connection wrapper duplication") + errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") + errExchangeConfigEmpty = errors.New("exchange config is empty") + errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") + errConnectionSignatureNotSet = errors.New("connection signature not set") + errMessageFilterNotComparable = errors.New("message filter is not comparable") ) var globalReporter Reporter @@ -265,15 +265,15 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset) } - if c.WrapperDefinedConnectionSignature != nil && !reflect.TypeOf(c.WrapperDefinedConnectionSignature).Comparable() { - return errWrapperDefinedConnectionSignatureNotComparable + if c.MessageFilter != nil && !reflect.TypeOf(c.MessageFilter).Comparable() { + return errMessageFilterNotComparable } for x := range w.connectionManager { // Below allows for multiple connections to the same URL with different outbound request signatures. This // allows for easier determination of inbound and outbound messages. e.g. Gateio cross_margin, margin on // a spot connection. - if w.connectionManager[x].Setup.URL == c.URL && c.WrapperDefinedConnectionSignature == w.connectionManager[x].Setup.WrapperDefinedConnectionSignature { + if w.connectionManager[x].Setup.URL == c.URL && c.MessageFilter == w.connectionManager[x].Setup.MessageFilter { return fmt.Errorf("%w: %w", errConnSetup, errConnectionWrapperDuplication) } } @@ -1286,7 +1286,7 @@ func (w *Websocket) GetConnection(connSignature any) (Connection, error) { } for _, wrapper := range w.connectionManager { - if wrapper.Setup.WrapperDefinedConnectionSignature == connSignature { + if wrapper.Setup.MessageFilter == connSignature { if wrapper.Connection == nil { return nil, fmt.Errorf("%s: %s %w: %v", w.exchangeName, wrapper.Setup.URL, ErrNotConnected, connSignature) } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index a1e4bdcf15d..8ca6709a10e 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1239,11 +1239,11 @@ func TestSetupNewConnection(t *testing.T) { require.ErrorIs(t, err, errWebsocketDataHandlerUnset) connSetup.Handler = func(context.Context, []byte) error { return nil } - connSetup.WrapperDefinedConnectionSignature = []string{"slices are super naughty and not comparable"} + connSetup.MessageFilter = []string{"slices are super naughty and not comparable"} err = multi.SetupNewConnection(connSetup) - require.ErrorIs(t, err, errWrapperDefinedConnectionSignatureNotComparable) + require.ErrorIs(t, err, errMessageFilterNotComparable) - connSetup.WrapperDefinedConnectionSignature = "comparable string signature" + connSetup.MessageFilter = "comparable string signature" err = multi.SetupNewConnection(connSetup) require.NoError(t, err) @@ -1525,7 +1525,7 @@ func TestGetConnection(t *testing.T) { require.ErrorIs(t, err, ErrRequestRouteNotFound) ws.connectionManager = []*ConnectionWrapper{{ - Setup: &ConnectionSetup{WrapperDefinedConnectionSignature: "testURL", URL: "testURL"}, + Setup: &ConnectionSetup{MessageFilter: "testURL", URL: "testURL"}, }} _, err = ws.GetConnection("testURL") From 3417a55ba670c5773bef167d55657cc7c0aa8b58 Mon Sep 17 00:00:00 2001 From: shazbert Date: Sun, 24 Nov 2024 08:03:29 +1100 Subject: [PATCH 121/138] linter: fix --- exchanges/bitfinex/bitfinex_websocket.go | 10 ++++------ exchanges/stream/stream_match.go | 1 + 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index 1d5425f33b3..0b213e92d33 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -534,18 +534,16 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { c.Key = int(chanID) // subscribeToChan removes the old subID keyed Subscription - if err := b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, c); err != nil { + err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, c) + if err != nil { return fmt.Errorf("%w: %w subID: %s", stream.ErrSubscriptionFailure, err, subID) } if b.Verbose { log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s Pair: %s ChannelID: %d\n", b.Name, c.Channel, c.Pairs, chanID) } - err = b.Websocket.Match.EnsureMatchWithData("subscribe:"+subID, respRaw) - if err != nil { - return fmt.Errorf("%w: subscribe:%v", err, subID) - } - return nil + + return b.Websocket.Match.EnsureMatchWithData("subscribe:"+subID, respRaw) } func (b *Bitfinex) handleWSChannelUpdate(s *subscription.Subscription, eventType string, d []interface{}) error { diff --git a/exchanges/stream/stream_match.go b/exchanges/stream/stream_match.go index d8fa488ae6a..6d8949e08e6 100644 --- a/exchanges/stream/stream_match.go +++ b/exchanges/stream/stream_match.go @@ -6,6 +6,7 @@ import ( "sync" ) +// ErrSignatureNotMatched is returned when a signature does not match a request var ErrSignatureNotMatched = errors.New("websocket response to request signature not matched") var ( From 70781cba9a8817ea95c2b9ef2db427117020102b Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 28 Nov 2024 10:26:43 +1100 Subject: [PATCH 122/138] gateio: update rate limit definitions (cherry-pick) --- exchanges/gateio/gateio.go | 506 +++++++++++------------------ exchanges/gateio/gateio_test.go | 2 +- exchanges/gateio/gateio_wrapper.go | 4 +- exchanges/gateio/ratelimiter.go | 439 ++++++++++++++++++++++--- 4 files changed, 589 insertions(+), 362 deletions(-) diff --git a/exchanges/gateio/gateio.go b/exchanges/gateio/gateio.go index 11d0dcd34c8..e4187d606fa 100644 --- a/exchanges/gateio/gateio.go +++ b/exchanges/gateio/gateio.go @@ -186,13 +186,13 @@ func (g *Gateio) CreateNewSubAccount(ctx context.Context, arg SubAccountParams) return nil, errors.New("login name can not be empty") } var response *SubAccount - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, http.MethodPost, subAccounts, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, subAccountEPL, http.MethodPost, subAccounts, nil, &arg, &response) } // GetSubAccounts retrieves list of sub-accounts for given account func (g *Gateio) GetSubAccounts(ctx context.Context) ([]SubAccount, error) { var response []SubAccount - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, http.MethodGet, subAccounts, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, subAccountEPL, http.MethodGet, subAccounts, nil, nil, &response) } // GetSingleSubAccount retrieves a single sub-account for given account @@ -201,8 +201,7 @@ func (g *Gateio) GetSingleSubAccount(ctx context.Context, userID string) (*SubAc return nil, errors.New("user ID can not be empty") } var response *SubAccount - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, http.MethodGet, - subAccounts+"/"+userID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, subAccountEPL, http.MethodGet, subAccounts+"/"+userID, nil, nil, &response) } // CreateAPIKeysOfSubAccount creates a sub-account for the sub-account @@ -217,18 +216,18 @@ func (g *Gateio) CreateAPIKeysOfSubAccount(ctx context.Context, arg CreateAPIKey return nil, errors.New("sub-account key information is required") } var resp *CreateAPIKeyResponse - return resp, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPost, subAccountsPath+strconv.FormatInt(arg.SubAccountUserID, 10)+"/keys", nil, &arg, &resp) + return resp, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, subAccountEPL, http.MethodPost, subAccountsPath+strconv.FormatInt(arg.SubAccountUserID, 10)+"/keys", nil, &arg, &resp) } // GetAllAPIKeyOfSubAccount list all API Key of the sub-account func (g *Gateio) GetAllAPIKeyOfSubAccount(ctx context.Context, userID int64) ([]CreateAPIKeyResponse, error) { var resp []CreateAPIKeyResponse - return resp, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, subAccountsPath+strconv.FormatInt(userID, 10)+"/keys", nil, nil, &resp) + return resp, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, subAccountEPL, http.MethodGet, subAccountsPath+strconv.FormatInt(userID, 10)+"/keys", nil, nil, &resp) } // UpdateAPIKeyOfSubAccount update API key of the sub-account func (g *Gateio) UpdateAPIKeyOfSubAccount(ctx context.Context, subAccountAPIKey string, arg CreateAPIKeySubAccountParams) error { - return g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPut, subAccountsPath+strconv.FormatInt(arg.SubAccountUserID, 10)+"/keys/"+subAccountAPIKey, nil, &arg, nil) + return g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, subAccountEPL, http.MethodPut, subAccountsPath+strconv.FormatInt(arg.SubAccountUserID, 10)+"/keys/"+subAccountAPIKey, nil, &arg, nil) } // GetAPIKeyOfSubAccount retrieves the API Key of the sub-account @@ -240,7 +239,7 @@ func (g *Gateio) GetAPIKeyOfSubAccount(ctx context.Context, subAccountUserID int return nil, errMissingAPIKey } var resp *CreateAPIKeyResponse - return resp, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, subAccountsPath+strconv.FormatInt(subAccountUserID, 10)+"/keys/"+apiKey, nil, nil, &resp) + return resp, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, subAccountEPL, http.MethodGet, subAccountsPath+strconv.FormatInt(subAccountUserID, 10)+"/keys/"+apiKey, nil, nil, &resp) } // LockSubAccount locks the sub-account @@ -248,7 +247,7 @@ func (g *Gateio) LockSubAccount(ctx context.Context, subAccountUserID int64) err if subAccountUserID == 0 { return errInvalidSubAccountUserID } - return g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodPost, subAccountsPath+strconv.FormatInt(subAccountUserID, 10)+"/lock", nil, nil, nil) + return g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, subAccountEPL, http.MethodPost, subAccountsPath+strconv.FormatInt(subAccountUserID, 10)+"/lock", nil, nil, nil) } // UnlockSubAccount locks the sub-account @@ -256,7 +255,7 @@ func (g *Gateio) UnlockSubAccount(ctx context.Context, subAccountUserID int64) e if subAccountUserID == 0 { return errInvalidSubAccountUserID } - return g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodPost, subAccountsPath+strconv.FormatInt(subAccountUserID, 10)+"/unlock", nil, nil, nil) + return g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, subAccountEPL, http.MethodPost, subAccountsPath+strconv.FormatInt(subAccountUserID, 10)+"/unlock", nil, nil, nil) } // ***************************************** Spot ************************************** @@ -264,7 +263,7 @@ func (g *Gateio) UnlockSubAccount(ctx context.Context, subAccountUserID int64) e // ListSpotCurrencies to retrieve detailed list of each currency. func (g *Gateio) ListSpotCurrencies(ctx context.Context) ([]CurrencyInfo, error) { var resp []CurrencyInfo - return resp, g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, gateioSpotCurrencies, &resp) + return resp, g.SendHTTPRequest(ctx, exchange.RestSpot, publicCurrenciesSpotEPL, gateioSpotCurrencies, &resp) } // GetCurrencyDetail details of a specific currency. @@ -273,14 +272,13 @@ func (g *Gateio) GetCurrencyDetail(ctx context.Context, ccy currency.Code) (*Cur return nil, currency.ErrCurrencyCodeEmpty } var resp *CurrencyInfo - return resp, g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, - gateioSpotCurrencies+"/"+ccy.String(), &resp) + return resp, g.SendHTTPRequest(ctx, exchange.RestSpot, publicCurrenciesSpotEPL, gateioSpotCurrencies+"/"+ccy.String(), &resp) } // ListSpotCurrencyPairs retrieve all currency pairs supported by the exchange. func (g *Gateio) ListSpotCurrencyPairs(ctx context.Context) ([]CurrencyPairDetail, error) { var resp []CurrencyPairDetail - return resp, g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, gateioSpotCurrencyPairs, &resp) + return resp, g.SendHTTPRequest(ctx, exchange.RestSpot, publicListCurrencyPairsSpotEPL, gateioSpotCurrencyPairs, &resp) } // GetCurrencyPairDetail to get details of a specific order for spot/margin accounts. @@ -289,8 +287,7 @@ func (g *Gateio) GetCurrencyPairDetail(ctx context.Context, currencyPair string) return nil, currency.ErrCurrencyPairEmpty } var resp *CurrencyPairDetail - return resp, g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, - gateioSpotCurrencyPairs+"/"+currencyPair, &resp) + return resp, g.SendHTTPRequest(ctx, exchange.RestSpot, publicCurrencyPairDetailSpotEPL, gateioSpotCurrencyPairs+"/"+currencyPair, &resp) } // GetTickers retrieve ticker information @@ -306,7 +303,7 @@ func (g *Gateio) GetTickers(ctx context.Context, currencyPair, timezone string) params.Set("timezone", timezone) } var tickers []Ticker - return tickers, g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, common.EncodeURLValues(gateioSpotTickers, params), &tickers) + return tickers, g.SendHTTPRequest(ctx, exchange.RestSpot, publicTickersSpotEPL, common.EncodeURLValues(gateioSpotTickers, params), &tickers) } // GetTicker retrieves a single ticker information for a currency pair. @@ -419,7 +416,7 @@ func (g *Gateio) GetOrderbook(ctx context.Context, pairString, interval string, } params.Set("with_id", strconv.FormatBool(withOrderbookID)) var response *OrderbookData - err := g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, common.EncodeURLValues(gateioSpotOrderbook, params), &response) + err := g.SendHTTPRequest(ctx, exchange.RestSpot, publicOrderbookSpotEPL, common.EncodeURLValues(gateioSpotOrderbook, params), &response) if err != nil { return nil, err } @@ -452,8 +449,7 @@ func (g *Gateio) GetMarketTrades(ctx context.Context, pairString currency.Pair, params.Set("page", strconv.FormatUint(page, 10)) } var response []Trade - return response, g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, - common.EncodeURLValues(gateioSpotMarketTrades, params), &response) + return response, g.SendHTTPRequest(ctx, exchange.RestSpot, publicMarketTradesSpotEPL, common.EncodeURLValues(gateioSpotMarketTrades, params), &response) } // GetCandlesticks retrieves market candlesticks. @@ -482,7 +478,7 @@ func (g *Gateio) GetCandlesticks(ctx context.Context, currencyPair currency.Pair params.Set("to", strconv.FormatInt(to.Unix(), 10)) } var candles [][7]string - err = g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, common.EncodeURLValues(gateioSpotCandlesticks, params), &candles) + err = g.SendHTTPRequest(ctx, exchange.RestSpot, publicCandleStickSpotEPL, common.EncodeURLValues(gateioSpotCandlesticks, params), &candles) if err != nil { return nil, err } @@ -540,8 +536,7 @@ func (g *Gateio) GetTradingFeeRatio(ctx context.Context, currencyPair currency.P params.Set("currency_pair", currencyPair.String()) } var response *SpotTradingFeeRate - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioSpotFeeRate, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotTradingFeeEPL, http.MethodGet, gateioSpotFeeRate, params, nil, &response) } // GetSpotAccounts retrieves spot account. @@ -551,8 +546,7 @@ func (g *Gateio) GetSpotAccounts(ctx context.Context, ccy currency.Code) ([]Spot params.Set("currency", ccy.String()) } var response []SpotAccount - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioSpotAccounts, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotAccountsEPL, http.MethodGet, gateioSpotAccounts, params, nil, &response) } // GetUnifiedAccount retrieves unified account. @@ -562,7 +556,7 @@ func (g *Gateio) GetUnifiedAccount(ctx context.Context, ccy currency.Code) (*Uni params.Set("currency", ccy.String()) } var response UnifiedUserAccount - return &response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioUnifiedAccounts, params, nil, &response) + return &response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, privateUnifiedSpotEPL, http.MethodGet, gateioUnifiedAccounts, params, nil, &response) } // CreateBatchOrders Create a batch of orders Batch orders requirements: custom order field text is required At most 4 currency pairs, @@ -598,14 +592,14 @@ func (g *Gateio) CreateBatchOrders(ctx context.Context, args []CreateOrderReques } } var response []SpotOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPost, gateioSpotBatchOrders, nil, &args, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotBatchOrdersEPL, http.MethodPost, gateioSpotBatchOrders, nil, &args, &response) } -// GateioSpotOpenOrders retrieves all open orders +// GetSpotOpenOrders retrieves all open orders // List open orders in all currency pairs. // Note that pagination parameters affect record number in each currency pair's open order list. No pagination is applied to the number of currency pairs returned. All currency pairs with open orders will be returned. // Spot and margin orders are returned by default. To list cross margin orders, account must be set to cross_margin -func (g *Gateio) GateioSpotOpenOrders(ctx context.Context, page, limit uint64, isCrossMargin bool) ([]SpotOrdersDetail, error) { +func (g *Gateio) GetSpotOpenOrders(ctx context.Context, page, limit uint64, isCrossMargin bool) ([]SpotOrdersDetail, error) { params := url.Values{} if page > 0 { params.Set("page", strconv.FormatUint(page, 10)) @@ -617,7 +611,7 @@ func (g *Gateio) GateioSpotOpenOrders(ctx context.Context, page, limit uint64, i params.Set("account", asset.CrossMargin.String()) } var response []SpotOrdersDetail - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioSpotOpenOrders, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotGetOpenOrdersEPL, http.MethodGet, gateioSpotOpenOrders, params, nil, &response) } // SpotClosePositionWhenCrossCurrencyDisabled set close position when cross-currency is disabled @@ -635,8 +629,7 @@ func (g *Gateio) SpotClosePositionWhenCrossCurrencyDisabled(ctx context.Context, return nil, errInvalidPrice } var response *SpotOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, - http.MethodPost, gateioSpotClosePositionWhenCrossCurrencyDisabledPath, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotClosePositionEPL, http.MethodPost, gateioSpotClosePositionWhenCrossCurrencyDisabledPath, nil, &arg, &response) } // PlaceSpotOrder creates a spot order you can place orders with spot, margin or cross margin account through setting the accountfield. @@ -664,7 +657,7 @@ func (g *Gateio) PlaceSpotOrder(ctx context.Context, arg *CreateOrderRequestData return nil, errInvalidPrice } var response *SpotOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPost, gateioSpotOrders, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrderEPL, http.MethodPost, gateioSpotOrders, nil, &arg, &response) } // GetSpotOrders retrieves spot orders. @@ -684,7 +677,7 @@ func (g *Gateio) GetSpotOrders(ctx context.Context, currencyPair currency.Pair, params.Set("limit", strconv.FormatUint(limit, 10)) } var response []SpotOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodGet, gateioSpotOrders, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotGetOrdersEPL, http.MethodGet, gateioSpotOrders, params, nil, &response) } // CancelAllOpenOrdersSpecifiedCurrencyPair cancel all open orders in specified currency pair @@ -701,8 +694,7 @@ func (g *Gateio) CancelAllOpenOrdersSpecifiedCurrencyPair(ctx context.Context, c params.Set("account", account.String()) } var response []SpotOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, spotCancelOrdersEPL, http.MethodDelete, gateioSpotOrders, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotCancelAllOpenOrdersEPL, http.MethodDelete, gateioSpotOrders, params, nil, &response) } // CancelBatchOrdersWithIDList cancels batch orders specifying the order ID and currency pair information @@ -719,7 +711,7 @@ func (g *Gateio) CancelBatchOrdersWithIDList(ctx context.Context, args []CancelO return nil, errors.New("currency pair and order ID are required") } } - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotCancelOrdersEPL, http.MethodPost, gateioSpotCancelBatchOrders, nil, &args, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotCancelBatchOrdersEPL, http.MethodPost, gateioSpotCancelBatchOrders, nil, &args, &response) } // GetSpotOrder retrieves a single spot order using the order id and currency pair information. @@ -736,8 +728,7 @@ func (g *Gateio) GetSpotOrder(ctx context.Context, orderID string, currencyPair params.Set("account", accountType) } var response *SpotOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioSpotOrders+"/"+orderID, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotGetOrderEPL, http.MethodGet, gateioSpotOrders+"/"+orderID, params, nil, &response) } // AmendSpotOrder amend an order @@ -763,7 +754,7 @@ func (g *Gateio) AmendSpotOrder(ctx context.Context, orderID string, currencyPai return nil, errors.New("only can chose one of amount or price") } var resp *SpotOrder - return resp, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPatch, gateioSpotOrders+"/"+orderID, params, arg, &resp) + return resp, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotAmendOrderEPL, http.MethodPatch, gateioSpotOrders+"/"+orderID, params, arg, &resp) } // CancelSingleSpotOrder cancels a single order @@ -782,8 +773,7 @@ func (g *Gateio) CancelSingleSpotOrder(ctx context.Context, orderID, currencyPai params.Set("account", asset.CrossMargin.String()) } var response *SpotOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, spotCancelOrdersEPL, http.MethodDelete, gateioSpotOrders+"/"+orderID, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotCancelSingleOrderEPL, http.MethodDelete, gateioSpotOrders+"/"+orderID, params, nil, &response) } // GateIOGetPersonalTradingHistory retrieves personal trading history @@ -812,7 +802,7 @@ func (g *Gateio) GateIOGetPersonalTradingHistory(ctx context.Context, currencyPa params.Set("to", strconv.FormatInt(to.Unix(), 10)) } var response []SpotPersonalTradeHistory - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioSpotMyTrades, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotTradingHistoryEPL, http.MethodGet, gateioSpotMyTrades, params, nil, &response) } // GetServerTime retrieves current server time @@ -820,7 +810,7 @@ func (g *Gateio) GetServerTime(ctx context.Context, _ asset.Item) (time.Time, er resp := struct { ServerTime int64 `json:"server_time"` }{} - err := g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, gateioSpotServerTime, &resp) + err := g.SendHTTPRequest(ctx, exchange.RestSpot, publicGetServerTimeEPL, gateioSpotServerTime, &resp) if err != nil { return time.Time{}, err } @@ -835,7 +825,7 @@ func (g *Gateio) CountdownCancelorders(ctx context.Context, arg CountdownCancelO return nil, errInvalidCountdown } var response *TriggerTimeResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotCancelOrdersEPL, http.MethodPost, gateioSpotAllCountdown, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotCountdownCancelEPL, http.MethodPost, gateioSpotAllCountdown, nil, &arg, &response) } // CreatePriceTriggeredOrder create a price-triggered order @@ -877,7 +867,7 @@ func (g *Gateio) CreatePriceTriggeredOrder(ctx context.Context, arg *PriceTrigge arg.Put.Account = "normal" } var response *OrderID - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPost, gateioSpotPriceOrders, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotCreateTriggerOrderEPL, http.MethodPost, gateioSpotPriceOrders, nil, &arg, &response) } // GetPriceTriggeredOrderList retrieves price orders created with an order detail and trigger price information. @@ -900,7 +890,7 @@ func (g *Gateio) GetPriceTriggeredOrderList(ctx context.Context, status string, params.Set("offset", strconv.FormatUint(offset, 10)) } var response []SpotPriceTriggeredOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioSpotPriceOrders, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotGetTriggerOrderListEPL, http.MethodGet, gateioSpotPriceOrders, params, nil, &response) } // CancelMultipleSpotOpenOrders deletes price triggered orders. @@ -918,7 +908,7 @@ func (g *Gateio) CancelMultipleSpotOpenOrders(ctx context.Context, currencyPair params.Set("account", account.String()) } var response []SpotPriceTriggeredOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotCancelOrdersEPL, http.MethodDelete, gateioSpotPriceOrders, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotCancelTriggerOrdersEPL, http.MethodDelete, gateioSpotPriceOrders, params, nil, &response) } // GetSinglePriceTriggeredOrder get a single order @@ -927,7 +917,7 @@ func (g *Gateio) GetSinglePriceTriggeredOrder(ctx context.Context, orderID strin return nil, errInvalidOrderID } var response *SpotPriceTriggeredOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioSpotPriceOrders+"/"+orderID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotGetTriggerOrderEPL, http.MethodGet, gateioSpotPriceOrders+"/"+orderID, nil, nil, &response) } // CancelPriceTriggeredOrder cancel a price-triggered order @@ -936,7 +926,7 @@ func (g *Gateio) CancelPriceTriggeredOrder(ctx context.Context, orderID string) return nil, errInvalidOrderID } var response *SpotPriceTriggeredOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotCancelOrdersEPL, http.MethodGet, gateioSpotPriceOrders+"/"+orderID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotCancelTriggerOrderEPL, http.MethodGet, gateioSpotPriceOrders+"/"+orderID, nil, nil, &response) } // GenerateSignature returns hash for authenticated requests @@ -1070,7 +1060,7 @@ func (g *Gateio) WithdrawCurrency(ctx context.Context, arg WithdrawalRequestPara return nil, errors.New("name of the chain used for withdrawal must be specified") } var response *WithdrawalResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, withdrawalEPL, http.MethodPost, withdrawal, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletWithdrawEPL, http.MethodPost, withdrawal, nil, &arg, &response) } // CancelWithdrawalWithSpecifiedID cancels withdrawal with specified ID. @@ -1079,7 +1069,7 @@ func (g *Gateio) CancelWithdrawalWithSpecifiedID(ctx context.Context, withdrawal return nil, errMissingWithdrawalID } var response *WithdrawalResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, withdrawalEPL, http.MethodDelete, withdrawal+"/"+withdrawalID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletCancelWithdrawEPL, http.MethodDelete, withdrawal+"/"+withdrawalID, nil, nil, &response) } // *********************************** Wallet *********************************** @@ -1092,7 +1082,7 @@ func (g *Gateio) ListCurrencyChain(ctx context.Context, ccy currency.Code) ([]Cu params := url.Values{} params.Set("currency", ccy.String()) var resp []CurrencyChain - return resp, g.SendHTTPRequest(ctx, exchange.RestSpot, walletEPL, common.EncodeURLValues(walletCurrencyChain, params), &resp) + return resp, g.SendHTTPRequest(ctx, exchange.RestSpot, publicListCurrencyChainEPL, common.EncodeURLValues(walletCurrencyChain, params), &resp) } // GenerateCurrencyDepositAddress generate currency deposit address @@ -1103,8 +1093,7 @@ func (g *Gateio) GenerateCurrencyDepositAddress(ctx context.Context, ccy currenc params := url.Values{} params.Set("currency", ccy.String()) var response *CurrencyDepositAddressInfo - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, - http.MethodGet, walletDepositAddress, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletDepositAddressEPL, http.MethodGet, walletDepositAddress, params, nil, &response) } // GetWithdrawalRecords retrieves withdrawal records. Record time range cannot exceed 30 days @@ -1128,8 +1117,7 @@ func (g *Gateio) GetWithdrawalRecords(ctx context.Context, ccy currency.Code, fr } } var withdrawals []WithdrawalResponse - return withdrawals, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, - http.MethodGet, walletWithdrawals, params, nil, &withdrawals) + return withdrawals, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletWithdrawalRecordsEPL, http.MethodGet, walletWithdrawals, params, nil, &withdrawals) } // GetDepositRecords retrieves deposit records. Record time range cannot exceed 30 days @@ -1151,8 +1139,7 @@ func (g *Gateio) GetDepositRecords(ctx context.Context, ccy currency.Code, from, } } var depositHistories []DepositRecord - return depositHistories, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, - http.MethodGet, walletDeposits, params, nil, &depositHistories) + return depositHistories, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletDepositRecordsEPL, http.MethodGet, walletDeposits, params, nil, &depositHistories) } // TransferCurrency Transfer between different accounts. Currently support transfers between the following: @@ -1184,7 +1171,7 @@ func (g *Gateio) TransferCurrency(ctx context.Context, arg *TransferCurrencyPara return nil, errInvalidAmount } var response *TransactionIDResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, http.MethodPost, walletTransfer, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletTransferCurrencyEPL, http.MethodPost, walletTransfer, nil, &arg, &response) } func (g *Gateio) assetTypeToString(acc asset.Item) string { @@ -1213,7 +1200,7 @@ func (g *Gateio) SubAccountTransfer(ctx context.Context, arg SubAccountTransferP if arg.SubAccountType != "" && arg.SubAccountType != asset.Spot.String() && arg.SubAccountType != asset.Futures.String() && arg.SubAccountType != asset.CrossMargin.String() { return fmt.Errorf("%v; only %v,%v, and %v are allowed", asset.ErrNotSupported, asset.Spot, asset.Futures, asset.CrossMargin) } - return g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, http.MethodPost, walletSubAccountTransfer, nil, &arg, nil) + return g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletSubAccountTransferEPL, http.MethodPost, walletSubAccountTransfer, nil, &arg, nil) } // GetSubAccountTransferHistory retrieve transfer records between main and sub accounts. @@ -1241,8 +1228,7 @@ func (g *Gateio) GetSubAccountTransferHistory(ctx context.Context, subAccountUse params.Set("limit", strconv.FormatUint(limit, 10)) } var response []SubAccountTransferResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, - http.MethodGet, walletSubAccountTransfer, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletSubAccountTransferHistoryEPL, http.MethodGet, walletSubAccountTransfer, params, nil, &response) } // SubAccountTransferToSubAccount performs sub-account transfers to sub-account @@ -1265,7 +1251,7 @@ func (g *Gateio) SubAccountTransferToSubAccount(ctx context.Context, arg *InterS if arg.Amount <= 0 { return errInvalidAmount } - return g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, http.MethodPost, walletInterSubAccountTransfer, nil, &arg, nil) + return g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletSubAccountToSubAccountTransferEPL, http.MethodPost, walletInterSubAccountTransfer, nil, &arg, nil) } // GetWithdrawalStatus retrieves withdrawal status @@ -1275,7 +1261,7 @@ func (g *Gateio) GetWithdrawalStatus(ctx context.Context, ccy currency.Code) ([] params.Set("currency", ccy.String()) } var response []WithdrawalStatus - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, http.MethodGet, walletWithdrawStatus, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletWithdrawStatusEPL, http.MethodGet, walletWithdrawStatus, params, nil, &response) } // GetSubAccountBalances retrieve sub account balances @@ -1285,7 +1271,7 @@ func (g *Gateio) GetSubAccountBalances(ctx context.Context, subAccountUserID str params.Set("sub_uid", subAccountUserID) } var response []FuturesSubAccountBalance - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, http.MethodGet, walletSubAccountBalance, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletSubAccountBalancesEPL, http.MethodGet, walletSubAccountBalance, params, nil, &response) } // GetSubAccountMarginBalances query sub accounts' margin balances @@ -1295,7 +1281,7 @@ func (g *Gateio) GetSubAccountMarginBalances(ctx context.Context, subAccountUser params.Set("sub_uid", subAccountUserID) } var response []SubAccountMarginBalance - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, http.MethodGet, walletSubAccountMarginBalance, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletSubAccountMarginBalancesEPL, http.MethodGet, walletSubAccountMarginBalance, params, nil, &response) } // GetSubAccountFuturesBalances retrieves sub accounts' futures account balances @@ -1308,7 +1294,7 @@ func (g *Gateio) GetSubAccountFuturesBalances(ctx context.Context, subAccountUse params.Set("settle", settle.Item.Lower) } var response []FuturesSubAccountBalance - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, http.MethodGet, walletSubAccountFuturesBalance, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletSubAccountFuturesBalancesEPL, http.MethodGet, walletSubAccountFuturesBalance, params, nil, &response) } // GetSubAccountCrossMarginBalances query subaccount's cross_margin account info @@ -1318,7 +1304,7 @@ func (g *Gateio) GetSubAccountCrossMarginBalances(ctx context.Context, subAccoun params.Set("sub_uid", subAccountUserID) } var response []SubAccountCrossMarginInfo - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, http.MethodGet, walletSubAccountCrossMarginBalances, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletSubAccountCrossMarginBalancesEPL, http.MethodGet, walletSubAccountCrossMarginBalances, params, nil, &response) } // GetSavedAddresses retrieves saved currency address info and related details. @@ -1335,7 +1321,7 @@ func (g *Gateio) GetSavedAddresses(ctx context.Context, ccy currency.Code, chain params.Set("limit", strconv.FormatUint(limit, 10)) } var response []WalletSavedAddress - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, http.MethodGet, walletSavedAddress, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletSavedAddressesEPL, http.MethodGet, walletSavedAddress, params, nil, &response) } // GetPersonalTradingFee retrieves personal trading fee @@ -1349,7 +1335,7 @@ func (g *Gateio) GetPersonalTradingFee(ctx context.Context, currencyPair currenc params.Set("settle", settle.Item.Lower) } var response *PersonalTradingFee - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, http.MethodGet, walletTradingFee, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletTradingFeeEPL, http.MethodGet, walletTradingFee, params, nil, &response) } // GetUsersTotalBalance retrieves user's total balances @@ -1359,7 +1345,7 @@ func (g *Gateio) GetUsersTotalBalance(ctx context.Context, ccy currency.Code) (* params.Set("currency", ccy.String()) } var response *UsersAllAccountBalance - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletEPL, http.MethodGet, walletTotalBalance, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, walletTotalBalanceEPL, http.MethodGet, walletTotalBalance, params, nil, &response) } // ********************************* Margin ******************************************* @@ -1367,7 +1353,7 @@ func (g *Gateio) GetUsersTotalBalance(ctx context.Context, ccy currency.Code) (* // GetMarginSupportedCurrencyPairs retrieves margin supported currency pairs. func (g *Gateio) GetMarginSupportedCurrencyPairs(ctx context.Context) ([]MarginCurrencyPairInfo, error) { var currenciePairsInfo []MarginCurrencyPairInfo - return currenciePairsInfo, g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, gateioMarginCurrencyPairs, ¤ciePairsInfo) + return currenciePairsInfo, g.SendHTTPRequest(ctx, exchange.RestSpot, publicCurrencyPairsMarginEPL, gateioMarginCurrencyPairs, ¤ciePairsInfo) } // GetSingleMarginSupportedCurrencyPair retrieves margin supported currency pair detail given the currency pair. @@ -1376,7 +1362,7 @@ func (g *Gateio) GetSingleMarginSupportedCurrencyPair(ctx context.Context, marke return nil, currency.ErrCurrencyPairEmpty } var currencyPairInfo *MarginCurrencyPairInfo - return currencyPairInfo, g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, gateioMarginCurrencyPairs+"/"+market.String(), ¤cyPairInfo) + return currencyPairInfo, g.SendHTTPRequest(ctx, exchange.RestSpot, publicCurrencyPairsMarginEPL, gateioMarginCurrencyPairs+"/"+market.String(), ¤cyPairInfo) } // GetOrderbookOfLendingLoans retrieves order book of lending loans for specific currency @@ -1385,8 +1371,7 @@ func (g *Gateio) GetOrderbookOfLendingLoans(ctx context.Context, ccy currency.Co return nil, currency.ErrCurrencyCodeEmpty } var lendingLoans []OrderbookOfLendingLoan - return lendingLoans, g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, - gateioMarginFundingBook+"?currency="+ccy.String(), &lendingLoans) + return lendingLoans, g.SendHTTPRequest(ctx, exchange.RestSpot, publicOrderbookMarginEPL, gateioMarginFundingBook+"?currency="+ccy.String(), &lendingLoans) } // GetMarginAccountList margin account list @@ -1396,7 +1381,7 @@ func (g *Gateio) GetMarginAccountList(ctx context.Context, currencyPair currency params.Set("currency_pair", currencyPair.String()) } var response []MarginAccountItem - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioMarginAccount, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginAccountListEPL, http.MethodGet, gateioMarginAccount, params, nil, &response) } // ListMarginAccountBalanceChangeHistory retrieves margin account balance change history @@ -1422,7 +1407,7 @@ func (g *Gateio) ListMarginAccountBalanceChangeHistory(ctx context.Context, ccy params.Set("limit", strconv.FormatUint(limit, 10)) } var response []MarginAccountBalanceChangeInfo - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioMarginAccountBook, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginAccountBalanceEPL, http.MethodGet, gateioMarginAccountBook, params, nil, &response) } // GetMarginFundingAccountList retrieves funding account list @@ -1432,7 +1417,7 @@ func (g *Gateio) GetMarginFundingAccountList(ctx context.Context, ccy currency.C params.Set("currency", ccy.String()) } var response []MarginFundingAccountItem - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioMarginFundingAccounts, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginFundingAccountListEPL, http.MethodGet, gateioMarginFundingAccounts, params, nil, &response) } // MarginLoan represents lend or borrow request @@ -1456,7 +1441,7 @@ func (g *Gateio) MarginLoan(ctx context.Context, arg *MarginLoanRequestParam) (* return nil, errors.New("invalid loan rate, rate must be between 0.0002 and 0.002") } var response *MarginLoanResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPost, gateioMarginLoans, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginLendBorrowEPL, http.MethodPost, gateioMarginLoans, nil, &arg, &response) } // GetMarginAllLoans retrieves all loans (borrow and lending) orders. @@ -1490,7 +1475,7 @@ func (g *Gateio) GetMarginAllLoans(ctx context.Context, status, side, sortBy str params.Set("limit", strconv.FormatUint(limit, 10)) } var response []MarginLoanResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioMarginLoans, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginAllLoansEPL, http.MethodGet, gateioMarginLoans, params, nil, &response) } // MergeMultipleLendingLoans merge multiple lending loans @@ -1505,7 +1490,7 @@ func (g *Gateio) MergeMultipleLendingLoans(ctx context.Context, ccy currency.Cod params.Set("currency", ccy.String()) params.Set("ids", strings.Join(ids, ",")) var response *MarginLoanResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPost, gateioMarginMergedLoans, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginMergeLendingLoansEPL, http.MethodPost, gateioMarginMergedLoans, params, nil, &response) } // RetriveOneSingleLoanDetail retrieve one single loan detail @@ -1520,7 +1505,7 @@ func (g *Gateio) RetriveOneSingleLoanDetail(ctx context.Context, side, loanID st params := url.Values{} params.Set("side", side) var response *MarginLoanResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioMarginLoans+"/"+loanID+"/", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginGetLoanEPL, http.MethodGet, gateioMarginLoans+"/"+loanID+"/", params, nil, &response) } // ModifyALoan Modify a loan @@ -1542,7 +1527,7 @@ func (g *Gateio) ModifyALoan(ctx context.Context, loanID string, arg *ModifyLoan return nil, currency.ErrCurrencyPairEmpty } var response *MarginLoanResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPatch, gateioMarginLoans+"/"+loanID, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginModifyLoanEPL, http.MethodPatch, gateioMarginLoans+"/"+loanID, nil, &arg, &response) } // CancelLendingLoan cancels lending loans. only lent loans can be canceled. @@ -1556,7 +1541,7 @@ func (g *Gateio) CancelLendingLoan(ctx context.Context, ccy currency.Code, loanI params := url.Values{} params.Set("currency", ccy.String()) var response *MarginLoanResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodDelete, gateioMarginLoans+"/"+loanID, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginCancelLoanEPL, http.MethodDelete, gateioMarginLoans+"/"+loanID, params, nil, &response) } // RepayALoan execute a loan repay. @@ -1580,7 +1565,7 @@ func (g *Gateio) RepayALoan(ctx context.Context, loanID string, arg *RepayLoanRe return nil, fmt.Errorf("%w, repay amount for partial repay mode must be greater than 0", errInvalidAmount) } var response *MarginLoanResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPost, gateioMarginLoans+"/"+loanID+"/repayment", nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginRepayLoanEPL, http.MethodPost, gateioMarginLoans+"/"+loanID+"/repayment", nil, &arg, &response) } // ListLoanRepaymentRecords retrieves loan repayment records for specified loan ID @@ -1589,7 +1574,7 @@ func (g *Gateio) ListLoanRepaymentRecords(ctx context.Context, loanID string) ([ return nil, fmt.Errorf("%w, %v", errInvalidLoanID, " loan_id is required") } var response []LoanRepaymentRecord - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioMarginLoans+"/"+loanID+"/repayment", nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginListLoansEPL, http.MethodGet, gateioMarginLoans+"/"+loanID+"/repayment", nil, nil, &response) } // ListRepaymentRecordsOfSpecificLoan retrieves repayment records of specific loan @@ -1609,7 +1594,7 @@ func (g *Gateio) ListRepaymentRecordsOfSpecificLoan(ctx context.Context, loanID, params.Set("limit", strconv.FormatUint(limit, 10)) } var response []LoanRecord - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioMarginLoanRecords, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginRepaymentRecordEPL, http.MethodGet, gateioMarginLoanRecords, params, nil, &response) } // GetOneSingleLoanRecord get one single loan record @@ -1623,7 +1608,7 @@ func (g *Gateio) GetOneSingleLoanRecord(ctx context.Context, loanID, loanRecordI params := url.Values{} params.Set("loan_id", loanID) var response *LoanRecord - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioMarginLoanRecords+"/"+loanRecordID, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginSingleRecordEPL, http.MethodGet, gateioMarginLoanRecords+"/"+loanRecordID, params, nil, &response) } // ModifyALoanRecord modify a loan record @@ -1645,7 +1630,7 @@ func (g *Gateio) ModifyALoanRecord(ctx context.Context, loanRecordID string, arg return nil, errInvalidLoanSide } var response *LoanRecord - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPatch, gateioMarginLoanRecords+"/"+loanRecordID, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginModifyLoanRecordEPL, http.MethodPatch, gateioMarginLoanRecords+"/"+loanRecordID, nil, &arg, &response) } // UpdateUsersAutoRepaymentSetting represents update user's auto repayment setting @@ -1659,13 +1644,13 @@ func (g *Gateio) UpdateUsersAutoRepaymentSetting(ctx context.Context, statusOn b params := url.Values{} params.Set("status", statusStr) var response *OnOffStatus - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPost, gateioMarginAutoRepay, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginAutoRepayEPL, http.MethodPost, gateioMarginAutoRepay, params, nil, &response) } // GetUserAutoRepaymentSetting retrieve user auto repayment setting func (g *Gateio) GetUserAutoRepaymentSetting(ctx context.Context) (*OnOffStatus, error) { var response *OnOffStatus - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioMarginAutoRepay, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginGetAutoRepaySettingsEPL, http.MethodGet, gateioMarginAutoRepay, nil, nil, &response) } // GetMaxTransferableAmountForSpecificMarginCurrency get the max transferable amount for a specific margin currency. @@ -1679,7 +1664,7 @@ func (g *Gateio) GetMaxTransferableAmountForSpecificMarginCurrency(ctx context.C } params.Set("currency", ccy.String()) var response *MaxTransferAndLoanAmount - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioMarginTransfer, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginGetMaxTransferEPL, http.MethodGet, gateioMarginTransfer, params, nil, &response) } // GetMaxBorrowableAmountForSpecificMarginCurrency retrieves the max borrowble amount for specific currency @@ -1693,13 +1678,13 @@ func (g *Gateio) GetMaxBorrowableAmountForSpecificMarginCurrency(ctx context.Con } params.Set("currency", ccy.String()) var response *MaxTransferAndLoanAmount - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioMarginBorrowable, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginGetMaxBorrowEPL, http.MethodGet, gateioMarginBorrowable, params, nil, &response) } // CurrencySupportedByCrossMargin currencies supported by cross margin. func (g *Gateio) CurrencySupportedByCrossMargin(ctx context.Context) ([]CrossMarginCurrencies, error) { var response []CrossMarginCurrencies - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioCrossMarginCurrencies, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginSupportedCurrencyCrossListEPL, http.MethodGet, gateioCrossMarginCurrencies, nil, nil, &response) } // GetCrossMarginSupportedCurrencyDetail retrieve detail of one single currency supported by cross margin @@ -1708,13 +1693,13 @@ func (g *Gateio) GetCrossMarginSupportedCurrencyDetail(ctx context.Context, ccy return nil, currency.ErrCurrencyCodeEmpty } var response *CrossMarginCurrencies - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioCrossMarginCurrencies+"/"+ccy.String(), nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginSupportedCurrencyCrossEPL, http.MethodGet, gateioCrossMarginCurrencies+"/"+ccy.String(), nil, nil, &response) } // GetCrossMarginAccounts retrieve cross margin account func (g *Gateio) GetCrossMarginAccounts(ctx context.Context) (*CrossMarginAccount, error) { var response *CrossMarginAccount - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioCrossMarginAccounts, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginAccountsEPL, http.MethodGet, gateioCrossMarginAccounts, nil, nil, &response) } // GetCrossMarginAccountChangeHistory retrieve cross margin account change history @@ -1740,7 +1725,7 @@ func (g *Gateio) GetCrossMarginAccountChangeHistory(ctx context.Context, ccy cur params.Set("type", accountChangeType) } var response []CrossMarginAccountHistoryItem - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioCrossMarginAccountBook, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginAccountHistoryEPL, http.MethodGet, gateioCrossMarginAccountBook, params, nil, &response) } // CreateCrossMarginBorrowLoan create a cross margin borrow loan @@ -1753,7 +1738,7 @@ func (g *Gateio) CreateCrossMarginBorrowLoan(ctx context.Context, arg CrossMargi return nil, fmt.Errorf("%w, borrow amount must be greater than 0", errInvalidAmount) } var response CrossMarginLoanResponse - return &response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPost, gateioCrossMarginLoans, nil, &arg, &response) + return &response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginCreateCrossBorrowLoanEPL, http.MethodPost, gateioCrossMarginLoans, nil, &arg, &response) } // ExecuteRepayment when the liquidity of the currency is insufficient and the transaction risk is high, the currency will be disabled, @@ -1767,7 +1752,7 @@ func (g *Gateio) ExecuteRepayment(ctx context.Context, arg CurrencyAndAmount) ([ return nil, fmt.Errorf("%w, repay amount must be greater than 0", errInvalidAmount) } var response []CrossMarginLoanResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPlaceOrdersEPL, http.MethodPost, gateioCrossMarginRepayments, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginExecuteRepaymentsEPL, http.MethodPost, gateioCrossMarginRepayments, nil, &arg, &response) } // GetCrossMarginRepayments retrieves list of cross margin repayments @@ -1789,7 +1774,7 @@ func (g *Gateio) GetCrossMarginRepayments(ctx context.Context, ccy currency.Code params.Set("reverse", "true") } var response []CrossMarginLoanResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioCrossMarginRepayments, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginGetCrossMarginRepaymentsEPL, http.MethodGet, gateioCrossMarginRepayments, params, nil, &response) } // GetMaxTransferableAmountForSpecificCrossMarginCurrency get the max transferable amount for a specific cross margin currency @@ -1800,7 +1785,7 @@ func (g *Gateio) GetMaxTransferableAmountForSpecificCrossMarginCurrency(ctx cont params := url.Values{} var response *CurrencyAndAmount params.Set("currency", ccy.String()) - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioCrossMarginTransferable, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginGetMaxTransferCrossEPL, http.MethodGet, gateioCrossMarginTransferable, params, nil, &response) } // GetMaxBorrowableAmountForSpecificCrossMarginCurrency returns the max borrowable amount for a specific cross margin currency @@ -1811,7 +1796,7 @@ func (g *Gateio) GetMaxBorrowableAmountForSpecificCrossMarginCurrency(ctx contex params := url.Values{} params.Set("currency", ccy.String()) var response *CurrencyAndAmount - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioCrossMarginBorrowable, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginGetMaxBorrowCrossEPL, http.MethodGet, gateioCrossMarginBorrowable, params, nil, &response) } // GetCrossMarginBorrowHistory retrieves cross margin borrow history sorted by creation time in descending order by default. @@ -1835,7 +1820,7 @@ func (g *Gateio) GetCrossMarginBorrowHistory(ctx context.Context, status uint64, params.Set("reverse", strconv.FormatBool(reverse)) } var response []CrossMarginLoanResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioCrossMarginLoans, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginGetCrossBorrowHistoryEPL, http.MethodGet, gateioCrossMarginLoans, params, nil, &response) } // GetSingleBorrowLoanDetail retrieve single borrow loan detail @@ -1844,18 +1829,18 @@ func (g *Gateio) GetSingleBorrowLoanDetail(ctx context.Context, loanID string) ( return nil, errInvalidLoanID } var response *CrossMarginLoanResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioCrossMarginLoans+"/"+loanID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, marginGetBorrowEPL, http.MethodGet, gateioCrossMarginLoans+"/"+loanID, nil, nil, &response) } // *********************************Futures*************************************** -// GetAllFutureContracts retrieves list all futures contracts +// GetAllFutureContracts retrieves list all futures contracts func (g *Gateio) GetAllFutureContracts(ctx context.Context, settle currency.Code) ([]FuturesContract, error) { if settle.IsEmpty() { return nil, errEmptyOrInvalidSettlementCurrency } var contracts []FuturesContract - return contracts, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, futuresPath+settle.Item.Lower+"/contracts", &contracts) + return contracts, g.SendHTTPRequest(ctx, exchange.RestSpot, publicFuturesContractsEPL, futuresPath+settle.Item.Lower+"/contracts", &contracts) } // GetSingleContract returns a single contract info for the specified settle and Currency Pair (contract << in this case) @@ -1867,7 +1852,7 @@ func (g *Gateio) GetSingleContract(ctx context.Context, settle currency.Code, co return nil, errEmptyOrInvalidSettlementCurrency } var futureContract *FuturesContract - return futureContract, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, futuresPath+settle.Item.Lower+"/contracts/"+contract, &futureContract) + return futureContract, g.SendHTTPRequest(ctx, exchange.RestSpot, publicFuturesContractsEPL, futuresPath+settle.Item.Lower+"/contracts/"+contract, &futureContract) } // GetFuturesOrderbook retrieves futures order book data @@ -1890,7 +1875,7 @@ func (g *Gateio) GetFuturesOrderbook(ctx context.Context, settle currency.Code, params.Set("with_id", "true") } var response *Orderbook - return response, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, common.EncodeURLValues(futuresPath+settle.Item.Lower+"/order_book", params), &response) + return response, g.SendHTTPRequest(ctx, exchange.RestSpot, publicOrderbookFuturesEPL, common.EncodeURLValues(futuresPath+settle.Item.Lower+"/order_book", params), &response) } // GetFuturesTradingHistory retrieves futures trading history @@ -1919,8 +1904,7 @@ func (g *Gateio) GetFuturesTradingHistory(ctx context.Context, settle currency.C params.Set("to", strconv.FormatInt(to.Unix(), 10)) } var response []TradingHistoryItem - return response, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, - common.EncodeURLValues(futuresPath+settle.Item.Lower+"/trades", params), &response) + return response, g.SendHTTPRequest(ctx, exchange.RestSpot, publicTradingHistoryFuturesEPL, common.EncodeURLValues(futuresPath+settle.Item.Lower+"/trades", params), &response) } // GetFuturesCandlesticks retrieves specified contract candlesticks. @@ -1950,9 +1934,7 @@ func (g *Gateio) GetFuturesCandlesticks(ctx context.Context, settle currency.Cod params.Set("interval", intervalString) } var candlesticks []FuturesCandlestick - return candlesticks, g.SendHTTPRequest(ctx, exchange.RestFutures, perpetualSwapDefaultEPL, - common.EncodeURLValues(futuresPath+settle.Item.Lower+"/candlesticks", params), - &candlesticks) + return candlesticks, g.SendHTTPRequest(ctx, exchange.RestFutures, publicCandleSticksFuturesEPL, common.EncodeURLValues(futuresPath+settle.Item.Lower+"/candlesticks", params), &candlesticks) } // PremiumIndexKLine retrieves premium Index K-Line @@ -1981,7 +1963,7 @@ func (g *Gateio) PremiumIndexKLine(ctx context.Context, settleCurrency currency. } params.Set("interval", intervalString) var resp []FuturesPremiumIndexKLineResponse - return resp, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, common.EncodeURLValues(futuresPath+settleCurrency.Item.Lower+"/premium_index", params), &resp) + return resp, g.SendHTTPRequest(ctx, exchange.RestSpot, publicPremiumIndexEPL, common.EncodeURLValues(futuresPath+settleCurrency.Item.Lower+"/premium_index", params), &resp) } // GetFuturesTickers retrieves futures ticker information for a specific settle and contract info. @@ -1994,7 +1976,7 @@ func (g *Gateio) GetFuturesTickers(ctx context.Context, settle currency.Code, co params.Set("contract", contract.String()) } var tickers []FuturesTicker - return tickers, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, common.EncodeURLValues(futuresPath+settle.Item.Lower+"/tickers", params), &tickers) + return tickers, g.SendHTTPRequest(ctx, exchange.RestSpot, publicTickersFuturesEPL, common.EncodeURLValues(futuresPath+settle.Item.Lower+"/tickers", params), &tickers) } // GetFutureFundingRates retrieves funding rate information. @@ -2011,7 +1993,7 @@ func (g *Gateio) GetFutureFundingRates(ctx context.Context, settle currency.Code params.Set("limit", strconv.FormatUint(limit, 10)) } var rates []FuturesFundingRate - return rates, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, common.EncodeURLValues(futuresPath+settle.Item.Lower+"/funding_rate", params), &rates) + return rates, g.SendHTTPRequest(ctx, exchange.RestSpot, publicFundingRatesEPL, common.EncodeURLValues(futuresPath+settle.Item.Lower+"/funding_rate", params), &rates) } // GetFuturesInsuranceBalanceHistory retrieves futures insurance balance history @@ -2024,10 +2006,7 @@ func (g *Gateio) GetFuturesInsuranceBalanceHistory(ctx context.Context, settle c params.Set("limit", strconv.FormatUint(limit, 10)) } var balances []InsuranceBalance - return balances, g.SendHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapDefaultEPL, - common.EncodeURLValues(futuresPath+settle.Item.Lower+"/insurance", params), - &balances) + return balances, g.SendHTTPRequest(ctx, exchange.RestSpot, publicInsuranceFuturesEPL, common.EncodeURLValues(futuresPath+settle.Item.Lower+"/insurance", params), &balances) } // GetFutureStats retrieves futures stats @@ -2054,10 +2033,7 @@ func (g *Gateio) GetFutureStats(ctx context.Context, settle currency.Code, contr params.Set("limit", strconv.FormatUint(limit, 10)) } var stats []ContractStat - return stats, g.SendHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapDefaultEPL, - common.EncodeURLValues(futuresPath+settle.Item.Lower+"/contract_stats", params), - &stats) + return stats, g.SendHTTPRequest(ctx, exchange.RestSpot, publicStatsFuturesEPL, common.EncodeURLValues(futuresPath+settle.Item.Lower+"/contract_stats", params), &stats) } // GetIndexConstituent retrieves index constituents @@ -2070,10 +2046,7 @@ func (g *Gateio) GetIndexConstituent(ctx context.Context, settle currency.Code, } indexString := strings.ToUpper(index) var constituents *IndexConstituent - return constituents, g.SendHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapDefaultEPL, - futuresPath+settle.Item.Lower+"/index_constituents/"+indexString, - &constituents) + return constituents, g.SendHTTPRequest(ctx, exchange.RestSpot, publicIndexConstituentsEPL, futuresPath+settle.Item.Lower+"/index_constituents/"+indexString, &constituents) } // GetLiquidationHistory retrieves liqudiation history @@ -2096,10 +2069,7 @@ func (g *Gateio) GetLiquidationHistory(ctx context.Context, settle currency.Code params.Set("limit", strconv.FormatUint(limit, 10)) } var histories []LiquidationHistory - return histories, g.SendHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapDefaultEPL, - common.EncodeURLValues(futuresPath+settle.Item.Lower+"/liq_orders", params), - &histories) + return histories, g.SendHTTPRequest(ctx, exchange.RestSpot, publicLiquidationHistoryEPL, common.EncodeURLValues(futuresPath+settle.Item.Lower+"/liq_orders", params), &histories) } // QueryFuturesAccount retrieves futures account @@ -2108,7 +2078,7 @@ func (g *Gateio) QueryFuturesAccount(ctx context.Context, settle currency.Code) return nil, errEmptyOrInvalidSettlementCurrency } var response *FuturesAccount - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, futuresPath+settle.Item.Lower+"/accounts", nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualAccountEPL, http.MethodGet, futuresPath+settle.Item.Lower+"/accounts", nil, nil, &response) } // GetFuturesAccountBooks retrieves account books @@ -2130,11 +2100,7 @@ func (g *Gateio) GetFuturesAccountBooks(ctx context.Context, settle currency.Cod params.Set("type", changingType) } var response []AccountBookItem - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, - http.MethodGet, futuresPath+settle.Item.Lower+"/account_book", - params, - nil, - &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualAccountBooksEPL, http.MethodGet, futuresPath+settle.Item.Lower+"/account_book", params, nil, &response) } // GetAllFuturesPositionsOfUsers list all positions of users. @@ -2147,7 +2113,7 @@ func (g *Gateio) GetAllFuturesPositionsOfUsers(ctx context.Context, settle curre params.Set("holding", "true") } var response []Position - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, futuresPath+settle.Item.Lower+"/positions", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualPositionsEPL, http.MethodGet, futuresPath+settle.Item.Lower+"/positions", params, nil, &response) } // GetSinglePosition returns a single position @@ -2159,9 +2125,7 @@ func (g *Gateio) GetSinglePosition(ctx context.Context, settle currency.Code, co return nil, fmt.Errorf("%w, currency pair for contract must not be empty", errInvalidOrMissingContractParam) } var response *Position - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, - http.MethodPost, futuresPath+settle.Item.Lower+positionsPath+contract.String(), - nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualPositionEPL, http.MethodPost, futuresPath+settle.Item.Lower+positionsPath+contract.String(), nil, nil, &response) } // UpdateFuturesPositionMargin represents account position margin for a futures contract. @@ -2178,10 +2142,7 @@ func (g *Gateio) UpdateFuturesPositionMargin(ctx context.Context, settle currenc params := url.Values{} params.Set("change", strconv.FormatFloat(change, 'f', -1, 64)) var response *Position - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapPlaceOrdersEPL, - http.MethodPost, futuresPath+settle.Item.Lower+positionsPath+contract.String()+"/margin", - params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualUpdateMarginEPL, http.MethodPost, futuresPath+settle.Item.Lower+positionsPath+contract.String()+"/margin", params, nil, &response) } // UpdateFuturesPositionLeverage update position leverage @@ -2201,9 +2162,7 @@ func (g *Gateio) UpdateFuturesPositionLeverage(ctx context.Context, settle curre params.Set("cross_leverage_limit", strconv.FormatFloat(crossLeverageLimit, 'f', -1, 64)) } var response *Position - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapPlaceOrdersEPL, http.MethodPost, - futuresPath+settle.Item.Lower+positionsPath+contract.String()+"/leverage", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualUpdateLeverageEPL, http.MethodPost, futuresPath+settle.Item.Lower+positionsPath+contract.String()+"/leverage", params, nil, &response) } // UpdateFuturesPositionRiskLimit updates the position risk limit @@ -2217,8 +2176,7 @@ func (g *Gateio) UpdateFuturesPositionRiskLimit(ctx context.Context, settle curr params := url.Values{} params.Set("risk_limit", strconv.FormatUint(riskLimit, 10)) var response *Position - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPlaceOrdersEPL, - http.MethodPost, futuresPath+settle.Item.Lower+positionsPath+contract.String()+"/risk_limit", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualUpdateRiskEPL, http.MethodPost, futuresPath+settle.Item.Lower+positionsPath+contract.String()+"/risk_limit", params, nil, &response) } // EnableOrDisableDualMode enable or disable dual mode @@ -2230,9 +2188,7 @@ func (g *Gateio) EnableOrDisableDualMode(ctx context.Context, settle currency.Co params := url.Values{} params.Set("dual_mode", strconv.FormatBool(dualMode)) var response *DualModeResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, - http.MethodGet, futuresPath+settle.Item.Lower+"/dual_mode", - params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualToggleDualModeEPL, http.MethodGet, futuresPath+settle.Item.Lower+"/dual_mode", params, nil, &response) } // RetrivePositionDetailInDualMode retrieve position detail in dual mode @@ -2244,10 +2200,7 @@ func (g *Gateio) RetrivePositionDetailInDualMode(ctx context.Context, settle cur return nil, fmt.Errorf("%w, currency pair for contract must not be empty", errInvalidOrMissingContractParam) } var response []Position - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - futuresPath+settle.Item.Lower+"/dual_comp/positions/"+contract.String(), - nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualPositionsDualModeEPL, http.MethodGet, futuresPath+settle.Item.Lower+"/dual_comp/positions/"+contract.String(), nil, nil, &response) } // UpdatePositionMarginInDualMode update position margin in dual mode @@ -2265,10 +2218,7 @@ func (g *Gateio) UpdatePositionMarginInDualMode(ctx context.Context, settle curr } params.Set("dual_side", dualSide) var response []Position - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPlaceOrdersEPL, - http.MethodPost, - futuresPath+settle.Item.Lower+"/dual_comp/positions/"+contract.String()+"/margin", - params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualUpdateMarginDualModeEPL, http.MethodPost, futuresPath+settle.Item.Lower+"/dual_comp/positions/"+contract.String()+"/margin", params, nil, &response) } // UpdatePositionLeverageInDualMode update position leverage in dual mode @@ -2288,7 +2238,7 @@ func (g *Gateio) UpdatePositionLeverageInDualMode(ctx context.Context, settle cu params.Set("cross_leverage_limit", strconv.FormatFloat(crossLeverageLimit, 'f', -1, 64)) } var response *Position - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPlaceOrdersEPL, http.MethodPost, futuresPath+settle.Item.Lower+"/dual_comp/positions/"+contract.String()+"/leverage", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualUpdateLeverageDualModeEPL, http.MethodPost, futuresPath+settle.Item.Lower+"/dual_comp/positions/"+contract.String()+"/leverage", params, nil, &response) } // UpdatePositionRiskLimitInDualMode update position risk limit in dual mode @@ -2305,10 +2255,7 @@ func (g *Gateio) UpdatePositionRiskLimitInDualMode(ctx context.Context, settle c params := url.Values{} params.Set("risk_limit", strconv.FormatFloat(riskLimit, 'f', -1, 64)) var response []Position - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapPlaceOrdersEPL, http.MethodPost, - futuresPath+settle.Item.Lower+"/dual_comp/positions/"+contract.String()+"/risk_limit", params, - nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualUpdateRiskDualModeEPL, http.MethodPost, futuresPath+settle.Item.Lower+"/dual_comp/positions/"+contract.String()+"/risk_limit", params, nil, &response) } // PlaceFuturesOrder creates futures order @@ -2345,14 +2292,7 @@ func (g *Gateio) PlaceFuturesOrder(ctx context.Context, arg *OrderCreateParams) } var response *Order - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, - perpetualSwapPlaceOrdersEPL, - http.MethodPost, - futuresPath+arg.Settle.Item.Lower+ordersPath, - nil, - &arg, - &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSubmitOrderEPL, http.MethodPost, futuresPath+arg.Settle.Item.Lower+ordersPath, nil, &arg, &response) } // GetFuturesOrders retrieves list of futures orders @@ -2384,9 +2324,7 @@ func (g *Gateio) GetFuturesOrders(ctx context.Context, contract currency.Pair, s return nil, errInvalidCountTotalValue } var response []Order - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, - http.MethodGet, futuresPath+settle.Item.Lower+ordersPath, - params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualGetOrdersEPL, http.MethodGet, futuresPath+settle.Item.Lower+ordersPath, params, nil, &response) } // CancelMultipleFuturesOpenOrders ancel all open orders @@ -2404,8 +2342,7 @@ func (g *Gateio) CancelMultipleFuturesOpenOrders(ctx context.Context, contract c } params.Set("contract", contract.String()) var response []Order - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapCancelOrdersEPL, - http.MethodDelete, futuresPath+settle.Item.Lower+ordersPath, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualGetOrdersEPL, http.MethodDelete, futuresPath+settle.Item.Lower+ordersPath, params, nil, &response) } // PlaceBatchFuturesOrders creates a list of futures orders @@ -2450,9 +2387,7 @@ func (g *Gateio) PlaceBatchFuturesOrders(ctx context.Context, settle currency.Co } } var response []Order - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, - http.MethodPost, futuresPath+settle.Item.Lower+"/batch_orders", - nil, &args, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSubmitBatchOrdersEPL, http.MethodPost, futuresPath+settle.Item.Lower+"/batch_orders", nil, &args, &response) } // GetSingleFuturesOrder retrieves a single order by its identifier @@ -2464,10 +2399,7 @@ func (g *Gateio) GetSingleFuturesOrder(ctx context.Context, settle currency.Code return nil, fmt.Errorf("%w, 'order_id' cannot be empty", errInvalidOrderID) } var response *Order - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapPrivateEPL, - http.MethodGet, futuresPath+settle.Item.Lower+"/orders/"+orderID, - nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualFetchOrderEPL, http.MethodGet, futuresPath+settle.Item.Lower+"/orders/"+orderID, nil, nil, &response) } // CancelSingleFuturesOrder cancel a single order @@ -2479,8 +2411,7 @@ func (g *Gateio) CancelSingleFuturesOrder(ctx context.Context, settle currency.C return nil, fmt.Errorf("%w, 'order_id' cannot be empty", errInvalidOrderID) } var response *Order - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapCancelOrdersEPL, http.MethodDelete, - futuresPath+settle.Item.Lower+"/orders/"+orderID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualCancelOrderEPL, http.MethodDelete, futuresPath+settle.Item.Lower+"/orders/"+orderID, nil, nil, &response) } // AmendFuturesOrder amends an existing futures order @@ -2495,8 +2426,7 @@ func (g *Gateio) AmendFuturesOrder(ctx context.Context, settle currency.Code, or return nil, errors.New("missing update 'size' or 'price', please specify 'size' or 'price' or both information") } var response *Order - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPlaceOrdersEPL, http.MethodPut, - futuresPath+settle.Item.Lower+"/orders/"+orderID, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualAmendOrderEPL, http.MethodPut, futuresPath+settle.Item.Lower+"/orders/"+orderID, nil, &arg, &response) } // GetMyPersonalTradingHistory retrieves my personal trading history @@ -2524,8 +2454,7 @@ func (g *Gateio) GetMyPersonalTradingHistory(ctx context.Context, settle currenc params.Set("count_total", strconv.FormatUint(countTotal, 10)) } var response []TradingHistoryItem - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - futuresPath+settle.Item.Lower+"/my_trades", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualTradingHistoryEPL, http.MethodGet, futuresPath+settle.Item.Lower+"/my_trades", params, nil, &response) } // GetFuturesPositionCloseHistory lists position close history @@ -2550,8 +2479,7 @@ func (g *Gateio) GetFuturesPositionCloseHistory(ctx context.Context, settle curr params.Set("to", strconv.FormatInt(to.Unix(), 10)) } var response []PositionCloseHistoryResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - futuresPath+settle.Item.Lower+"/position_close", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualClosePositionEPL, http.MethodGet, futuresPath+settle.Item.Lower+"/position_close", params, nil, &response) } // GetFuturesLiquidationHistory list liquidation history @@ -2570,9 +2498,7 @@ func (g *Gateio) GetFuturesLiquidationHistory(ctx context.Context, settle curren params.Set("at", strconv.FormatInt(at.Unix(), 10)) } var response []LiquidationHistoryItem - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - futuresPath+settle.Item.Lower+"/liquidates", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualLiquidationHistoryEPL, http.MethodGet, futuresPath+settle.Item.Lower+"/liquidates", params, nil, &response) } // CountdownCancelOrders represents a trigger time response @@ -2584,9 +2510,7 @@ func (g *Gateio) CountdownCancelOrders(ctx context.Context, settle currency.Code return nil, errInvalidTimeout } var response *TriggerTimeResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapPlaceOrdersEPL, http.MethodPost, - futuresPath+settle.Item.Lower+"/countdown_cancel_all", nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualCancelTriggerOrdersEPL, http.MethodPost, futuresPath+settle.Item.Lower+"/countdown_cancel_all", nil, &arg, &response) } // CreatePriceTriggeredFuturesOrder create a price-triggered order @@ -2625,8 +2549,7 @@ func (g *Gateio) CreatePriceTriggeredFuturesOrder(ctx context.Context, settle cu return nil, errors.New("invalid order type, only 'close-long-order', 'close-short-order', 'close-long-position', 'close-short-position', 'plan-close-long-position', and 'plan-close-short-position'") } var response *OrderID - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPlaceOrdersEPL, http.MethodPost, - futuresPath+settle.Item.Lower+priceOrdersPaths, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSubmitTriggerOrderEPL, http.MethodPost, futuresPath+settle.Item.Lower+priceOrdersPaths, nil, &arg, &response) } // ListAllFuturesAutoOrders lists all open orders @@ -2649,10 +2572,7 @@ func (g *Gateio) ListAllFuturesAutoOrders(ctx context.Context, status string, se params.Set("contract", contract.String()) } var response []PriceTriggeredOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, - http.MethodGet, - futuresPath+settle.Item.Lower+priceOrdersPaths, - params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualListOpenOrdersEPL, http.MethodGet, futuresPath+settle.Item.Lower+priceOrdersPaths, params, nil, &response) } // CancelAllFuturesOpenOrders cancels all futures open orders @@ -2666,8 +2586,7 @@ func (g *Gateio) CancelAllFuturesOpenOrders(ctx context.Context, settle currency params := url.Values{} params.Set("contract", contract.String()) var response []PriceTriggeredOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapCancelOrdersEPL, http.MethodDelete, - futuresPath+settle.Item.Lower+priceOrdersPaths, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualCancelOpenOrdersEPL, http.MethodDelete, futuresPath+settle.Item.Lower+priceOrdersPaths, params, nil, &response) } // GetSingleFuturesPriceTriggeredOrder retrieves a single price triggered order @@ -2679,9 +2598,7 @@ func (g *Gateio) GetSingleFuturesPriceTriggeredOrder(ctx context.Context, settle return nil, errInvalidOrderID } var response *PriceTriggeredOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, - http.MethodGet, - futuresPath+settle.Item.Lower+"/price_orders/"+orderID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualGetTriggerOrderEPL, http.MethodGet, futuresPath+settle.Item.Lower+"/price_orders/"+orderID, nil, nil, &response) } // CancelFuturesPriceTriggeredOrder cancel a price-triggered order @@ -2693,8 +2610,7 @@ func (g *Gateio) CancelFuturesPriceTriggeredOrder(ctx context.Context, settle cu return nil, errInvalidOrderID } var response *PriceTriggeredOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapCancelOrdersEPL, http.MethodDelete, - futuresPath+settle.Item.Lower+"/price_orders/"+orderID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualCancelTriggerOrderEPL, http.MethodDelete, futuresPath+settle.Item.Lower+"/price_orders/"+orderID, nil, nil, &response) } // *************************************** Delivery *************************************** @@ -2705,8 +2621,7 @@ func (g *Gateio) GetAllDeliveryContracts(ctx context.Context, settle currency.Co return nil, errEmptyOrInvalidSettlementCurrency } var contracts []DeliveryContract - return contracts, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, - deliveryPath+settle.Item.Lower+"/contracts", &contracts) + return contracts, g.SendHTTPRequest(ctx, exchange.RestSpot, publicDeliveryContractsEPL, deliveryPath+settle.Item.Lower+"/contracts", &contracts) } // GetSingleDeliveryContracts retrieves a single delivery contract instance. @@ -2715,8 +2630,7 @@ func (g *Gateio) GetSingleDeliveryContracts(ctx context.Context, settle currency return nil, errEmptyOrInvalidSettlementCurrency } var deliveryContract *DeliveryContract - return deliveryContract, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, - deliveryPath+settle.Item.Lower+"/contracts/"+contract.String(), &deliveryContract) + return deliveryContract, g.SendHTTPRequest(ctx, exchange.RestSpot, publicDeliveryContractsEPL, deliveryPath+settle.Item.Lower+"/contracts/"+contract.String(), &deliveryContract) } // GetDeliveryOrderbook delivery orderbook @@ -2739,7 +2653,7 @@ func (g *Gateio) GetDeliveryOrderbook(ctx context.Context, settle currency.Code, params.Set("with_id", strconv.FormatBool(withOrderbookID)) } var orderbook *Orderbook - return orderbook, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, common.EncodeURLValues(deliveryPath+settle.Item.Lower+"/order_book", params), &orderbook) + return orderbook, g.SendHTTPRequest(ctx, exchange.RestSpot, publicOrderbookDeliveryEPL, common.EncodeURLValues(deliveryPath+settle.Item.Lower+"/order_book", params), &orderbook) } // GetDeliveryTradingHistory retrieves futures trading history @@ -2765,8 +2679,7 @@ func (g *Gateio) GetDeliveryTradingHistory(ctx context.Context, settle currency. params.Set("last_id", lastID) } var histories []DeliveryTradingHistory - return histories, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, - common.EncodeURLValues(deliveryPath+settle.Item.Lower+"/trades", params), &histories) + return histories, g.SendHTTPRequest(ctx, exchange.RestSpot, publicTradingHistoryDeliveryEPL, common.EncodeURLValues(deliveryPath+settle.Item.Lower+"/trades", params), &histories) } // GetDeliveryFuturesCandlesticks retrieves specified contract candlesticks @@ -2796,10 +2709,7 @@ func (g *Gateio) GetDeliveryFuturesCandlesticks(ctx context.Context, settle curr params.Set("interval", intervalString) } var candlesticks []FuturesCandlestick - return candlesticks, g.SendHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapDefaultEPL, - common.EncodeURLValues(deliveryPath+settle.Item.Lower+"/candlesticks", params), - &candlesticks) + return candlesticks, g.SendHTTPRequest(ctx, exchange.RestSpot, publicCandleSticksDeliveryEPL, common.EncodeURLValues(deliveryPath+settle.Item.Lower+"/candlesticks", params), &candlesticks) } // GetDeliveryFutureTickers retrieves futures ticker information for a specific settle and contract info. @@ -2812,7 +2722,7 @@ func (g *Gateio) GetDeliveryFutureTickers(ctx context.Context, settle currency.C params.Set("contract", contract.String()) } var tickers []FuturesTicker - return tickers, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, common.EncodeURLValues(deliveryPath+settle.Item.Lower+"/tickers", params), &tickers) + return tickers, g.SendHTTPRequest(ctx, exchange.RestSpot, publicTickersDeliveryEPL, common.EncodeURLValues(deliveryPath+settle.Item.Lower+"/tickers", params), &tickers) } // GetDeliveryInsuranceBalanceHistory retrieves delivery futures insurance balance history @@ -2825,9 +2735,7 @@ func (g *Gateio) GetDeliveryInsuranceBalanceHistory(ctx context.Context, settle params.Set("limit", strconv.FormatUint(limit, 10)) } var balances []InsuranceBalance - return balances, g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, - common.EncodeURLValues(deliveryPath+settle.Item.Lower+"/insurance", params), - &balances) + return balances, g.SendHTTPRequest(ctx, exchange.RestSpot, publicInsuranceDeliveryEPL, common.EncodeURLValues(deliveryPath+settle.Item.Lower+"/insurance", params), &balances) } // GetDeliveryFuturesAccounts retrieves futures account @@ -2836,7 +2744,7 @@ func (g *Gateio) GetDeliveryFuturesAccounts(ctx context.Context, settle currency return nil, errEmptyOrInvalidSettlementCurrency } var response *FuturesAccount - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, deliveryPath+settle.Item.Lower+"/accounts", nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryAccountEPL, http.MethodGet, deliveryPath+settle.Item.Lower+"/accounts", nil, nil, &response) } // GetDeliveryAccountBooks retrieves account books @@ -2858,10 +2766,7 @@ func (g *Gateio) GetDeliveryAccountBooks(ctx context.Context, settle currency.Co params.Set("type", changingType) } var response []AccountBookItem - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - deliveryPath+settle.Item.Lower+"/account_book", - params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryAccountBooksEPL, http.MethodGet, deliveryPath+settle.Item.Lower+"/account_book", params, nil, &response) } // GetAllDeliveryPositionsOfUser retrieves all positions of user @@ -2870,8 +2775,7 @@ func (g *Gateio) GetAllDeliveryPositionsOfUser(ctx context.Context, settle curre return nil, errEmptyOrInvalidSettlementCurrency } var response *Position - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - deliveryPath+settle.Item.Lower+"/positions", nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryPositionsEPL, http.MethodGet, deliveryPath+settle.Item.Lower+"/positions", nil, nil, &response) } // GetSingleDeliveryPosition get single position @@ -2883,9 +2787,7 @@ func (g *Gateio) GetSingleDeliveryPosition(ctx context.Context, settle currency. return nil, fmt.Errorf("%w, currency pair for contract must not be empty", errInvalidOrMissingContractParam) } var response *Position - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - deliveryPath+settle.Item.Lower+positionsPath+contract.String(), - nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryPositionsEPL, http.MethodGet, deliveryPath+settle.Item.Lower+positionsPath+contract.String(), nil, nil, &response) } // UpdateDeliveryPositionMargin updates position margin @@ -2902,8 +2804,7 @@ func (g *Gateio) UpdateDeliveryPositionMargin(ctx context.Context, settle curren params := url.Values{} params.Set("change", strconv.FormatFloat(change, 'f', -1, 64)) var response *Position - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPlaceOrdersEPL, http.MethodPost, - deliveryPath+settle.Item.Lower+positionsPath+contract.String()+"/margin", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryUpdateMarginEPL, http.MethodPost, deliveryPath+settle.Item.Lower+positionsPath+contract.String()+"/margin", params, nil, &response) } // UpdateDeliveryPositionLeverage updates position leverage @@ -2921,9 +2822,7 @@ func (g *Gateio) UpdateDeliveryPositionLeverage(ctx context.Context, settle curr params.Set("leverage", strconv.FormatFloat(leverage, 'f', -1, 64)) var response *Position return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapPlaceOrdersEPL, http.MethodPost, - deliveryPath+settle.Item.Lower+positionsPath+contract.String()+"/leverage", - params, nil, &response) + exchange.RestSpot, deliveryUpdateLeverageEPL, http.MethodPost, deliveryPath+settle.Item.Lower+positionsPath+contract.String()+"/leverage", params, nil, &response) } // UpdateDeliveryPositionRiskLimit update position risk limit @@ -2937,8 +2836,7 @@ func (g *Gateio) UpdateDeliveryPositionRiskLimit(ctx context.Context, settle cur params := url.Values{} params.Set("risk_limit", strconv.FormatUint(riskLimit, 10)) var response *Position - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPlaceOrdersEPL, http.MethodPost, - deliveryPath+settle.Item.Lower+positionsPath+contract.String()+"/risk_limit", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryUpdateRiskLimitEPL, http.MethodPost, deliveryPath+settle.Item.Lower+positionsPath+contract.String()+"/risk_limit", params, nil, &response) } // PlaceDeliveryOrder create a futures order @@ -2966,8 +2864,7 @@ func (g *Gateio) PlaceDeliveryOrder(ctx context.Context, arg *OrderCreateParams) return nil, errEmptyOrInvalidSettlementCurrency } var response *Order - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPlaceOrdersEPL, http.MethodPost, - deliveryPath+arg.Settle.Item.Lower+ordersPath, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliverySubmitOrderEPL, http.MethodPost, deliveryPath+arg.Settle.Item.Lower+ordersPath, nil, &arg, &response) } // GetDeliveryOrders list futures orders @@ -2999,8 +2896,7 @@ func (g *Gateio) GetDeliveryOrders(ctx context.Context, contract currency.Pair, return nil, errInvalidCountTotalValue } var response []Order - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - deliveryPath+settle.Item.Lower+ordersPath, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryGetOrdersEPL, http.MethodGet, deliveryPath+settle.Item.Lower+ordersPath, params, nil, &response) } // CancelMultipleDeliveryOrders cancel all open orders matched @@ -3018,8 +2914,7 @@ func (g *Gateio) CancelMultipleDeliveryOrders(ctx context.Context, contract curr } params.Set("contract", contract.String()) var response []Order - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapCancelOrdersEPL, http.MethodDelete, - deliveryPath+settle.Item.Lower+ordersPath, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryCancelOrdersEPL, http.MethodDelete, deliveryPath+settle.Item.Lower+ordersPath, params, nil, &response) } // GetSingleDeliveryOrder Get a single order @@ -3032,8 +2927,7 @@ func (g *Gateio) GetSingleDeliveryOrder(ctx context.Context, settle currency.Cod return nil, fmt.Errorf("%w, 'order_id' cannot be empty", errInvalidOrderID) } var response *Order - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - deliveryPath+settle.Item.Lower+"/orders/"+orderID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryGetOrderEPL, http.MethodGet, deliveryPath+settle.Item.Lower+"/orders/"+orderID, nil, nil, &response) } // CancelSingleDeliveryOrder cancel a single order @@ -3045,8 +2939,7 @@ func (g *Gateio) CancelSingleDeliveryOrder(ctx context.Context, settle currency. return nil, fmt.Errorf("%w, 'order_id' cannot be empty", errInvalidOrderID) } var response *Order - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapCancelOrdersEPL, http.MethodDelete, - deliveryPath+settle.Item.Lower+"/orders/"+orderID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryCancelOrderEPL, http.MethodDelete, deliveryPath+settle.Item.Lower+"/orders/"+orderID, nil, nil, &response) } // GetDeliveryPersonalTradingHistory retrieves personal trading history @@ -3074,8 +2967,7 @@ func (g *Gateio) GetDeliveryPersonalTradingHistory(ctx context.Context, settle c params.Set("count_total", strconv.FormatUint(countTotal, 10)) } var response []TradingHistoryItem - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - deliveryPath+settle.Item.Lower+"/my_trades", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryTradingHistoryEPL, http.MethodGet, deliveryPath+settle.Item.Lower+"/my_trades", params, nil, &response) } // GetDeliveryPositionCloseHistory retrieves position history @@ -3100,8 +2992,7 @@ func (g *Gateio) GetDeliveryPositionCloseHistory(ctx context.Context, settle cur params.Set("to", strconv.FormatInt(to.Unix(), 10)) } var response []PositionCloseHistoryResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - deliveryPath+settle.Item.Lower+"/position_close", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryCloseHistoryEPL, http.MethodGet, deliveryPath+settle.Item.Lower+"/position_close", params, nil, &response) } // GetDeliveryLiquidationHistory lists liquidation history @@ -3120,8 +3011,7 @@ func (g *Gateio) GetDeliveryLiquidationHistory(ctx context.Context, settle curre params.Set("at", strconv.FormatInt(at.Unix(), 10)) } var response []LiquidationHistoryItem - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - deliveryPath+settle.Item.Lower+"/liquidates", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryLiquidationHistoryEPL, http.MethodGet, deliveryPath+settle.Item.Lower+"/liquidates", params, nil, &response) } // GetDeliverySettlementHistory retrieves settlement history @@ -3140,8 +3030,7 @@ func (g *Gateio) GetDeliverySettlementHistory(ctx context.Context, settle curren params.Set("at", strconv.FormatInt(at.Unix(), 10)) } var response []SettlementHistoryItem - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - deliveryPath+settle.Item.Lower+"/settlements", params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliverySettlementHistoryEPL, http.MethodGet, deliveryPath+settle.Item.Lower+"/settlements", params, nil, &response) } // GetDeliveryPriceTriggeredOrder creates a price-triggered order @@ -3187,8 +3076,7 @@ func (g *Gateio) GetDeliveryPriceTriggeredOrder(ctx context.Context, settle curr return nil, errors.New("invalid order type, only 'close-long-order', 'close-short-order', 'close-long-position', 'close-short-position', 'plan-close-long-position', and 'plan-close-short-position'") } var response *OrderID - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodPost, - deliveryPath+settle.Item.Lower+priceOrdersPaths, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryGetTriggerOrderEPL, http.MethodPost, deliveryPath+settle.Item.Lower+priceOrdersPaths, nil, &arg, &response) } // GetDeliveryAllAutoOrder retrieves all auto orders @@ -3211,8 +3099,7 @@ func (g *Gateio) GetDeliveryAllAutoOrder(ctx context.Context, status string, set params.Set("contract", contract.String()) } var response []PriceTriggeredOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - deliveryPath+settle.Item.Lower+priceOrdersPaths, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryAutoOrdersEPL, http.MethodGet, deliveryPath+settle.Item.Lower+priceOrdersPaths, params, nil, &response) } // CancelAllDeliveryPriceTriggeredOrder cancels all delivery price triggered orders @@ -3226,8 +3113,7 @@ func (g *Gateio) CancelAllDeliveryPriceTriggeredOrder(ctx context.Context, settl params := url.Values{} params.Set("contract", contract.String()) var response []PriceTriggeredOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapCancelOrdersEPL, http.MethodDelete, - deliveryPath+settle.Item.Lower+priceOrdersPaths, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryCancelTriggerOrdersEPL, http.MethodDelete, deliveryPath+settle.Item.Lower+priceOrdersPaths, params, nil, &response) } // GetSingleDeliveryPriceTriggeredOrder retrieves a single price triggered order @@ -3239,8 +3125,7 @@ func (g *Gateio) GetSingleDeliveryPriceTriggeredOrder(ctx context.Context, settl return nil, errInvalidOrderID } var response *PriceTriggeredOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - deliveryPath+settle.Item.Lower+"/price_orders/"+orderID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryGetTriggerOrderEPL, http.MethodGet, deliveryPath+settle.Item.Lower+"/price_orders/"+orderID, nil, nil, &response) } // CancelDeliveryPriceTriggeredOrder cancel a price-triggered order @@ -3252,8 +3137,7 @@ func (g *Gateio) CancelDeliveryPriceTriggeredOrder(ctx context.Context, settle c return nil, errInvalidOrderID } var response *PriceTriggeredOrder - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapCancelOrdersEPL, http.MethodDelete, - deliveryPath+settle.Item.Lower+"/price_orders/"+orderID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, deliveryCancelTriggerOrderEPL, http.MethodDelete, deliveryPath+settle.Item.Lower+"/price_orders/"+orderID, nil, nil, &response) } // ********************************** Options *************************************************** @@ -3261,13 +3145,13 @@ func (g *Gateio) CancelDeliveryPriceTriggeredOrder(ctx context.Context, settle c // GetAllOptionsUnderlyings retrieves all option underlyings func (g *Gateio) GetAllOptionsUnderlyings(ctx context.Context) ([]OptionUnderlying, error) { var response []OptionUnderlying - return response, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, gateioOptionUnderlyings, &response) + return response, g.SendHTTPRequest(ctx, exchange.RestSpot, publicUnderlyingOptionsEPL, gateioOptionUnderlyings, &response) } // GetExpirationTime return the expiration time for the provided underlying. func (g *Gateio) GetExpirationTime(ctx context.Context, underlying string) (time.Time, error) { var timestamp []float64 - err := g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, gateioOptionExpiration+"?underlying="+underlying, ×tamp) + err := g.SendHTTPRequest(ctx, exchange.RestSpot, publicExpirationOptionsEPL, gateioOptionExpiration+"?underlying="+underlying, ×tamp) if err != nil { return time.Time{}, err } @@ -3288,7 +3172,7 @@ func (g *Gateio) GetAllContractOfUnderlyingWithinExpiryDate(ctx context.Context, params.Set("expires", strconv.FormatInt(expTime.Unix(), 10)) } var contracts []OptionContract - return contracts, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, common.EncodeURLValues(gateioOptionContracts, params), &contracts) + return contracts, g.SendHTTPRequest(ctx, exchange.RestSpot, publicContractsOptionsEPL, common.EncodeURLValues(gateioOptionContracts, params), &contracts) } // GetOptionsSpecifiedContractDetail query specified contract detail @@ -3297,8 +3181,7 @@ func (g *Gateio) GetOptionsSpecifiedContractDetail(ctx context.Context, contract return nil, errInvalidOrMissingContractParam } var contr *OptionContract - return contr, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, - gateioOptionContracts+"/"+contract.String(), &contr) + return contr, g.SendHTTPRequest(ctx, exchange.RestSpot, publicContractsOptionsEPL, gateioOptionContracts+"/"+contract.String(), &contr) } // GetSettlementHistory retrieves list of settlement history @@ -3321,8 +3204,7 @@ func (g *Gateio) GetSettlementHistory(ctx context.Context, underlying string, of params.Set("to", strconv.FormatInt(to.Unix(), 10)) } var settlements []OptionSettlement - return settlements, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, - common.EncodeURLValues(gateioOptionSettlement, params), &settlements) + return settlements, g.SendHTTPRequest(ctx, exchange.RestSpot, publicSettlementOptionsEPL, common.EncodeURLValues(gateioOptionSettlement, params), &settlements) } // GetOptionsSpecifiedContractsSettlement retrieve a single contract settlement detail passing the underlying and contract name @@ -3337,7 +3219,7 @@ func (g *Gateio) GetOptionsSpecifiedContractsSettlement(ctx context.Context, con params.Set("underlying", underlying) params.Set("at", strconv.FormatInt(at, 10)) var settlement *OptionSettlement - return settlement, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, common.EncodeURLValues(gateioOptionSettlement+"/"+contract.String(), params), &settlement) + return settlement, g.SendHTTPRequest(ctx, exchange.RestSpot, publicSettlementOptionsEPL, common.EncodeURLValues(gateioOptionSettlement+"/"+contract.String(), params), &settlement) } // GetMyOptionsSettlements retrieves accounts option settlements. @@ -3360,7 +3242,7 @@ func (g *Gateio) GetMyOptionsSettlements(ctx context.Context, underlying string, params.Set("limit", strconv.Itoa(int(limit))) } var settlements []MyOptionSettlement - return settlements, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, gateioOptionMySettlements, params, nil, &settlements) + return settlements, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, optionsSettlementsEPL, http.MethodGet, gateioOptionMySettlements, params, nil, &settlements) } // GetOptionsOrderbook returns the orderbook data for the given contract. @@ -3378,13 +3260,13 @@ func (g *Gateio) GetOptionsOrderbook(ctx context.Context, contract currency.Pair } params.Set("with_id", strconv.FormatBool(withOrderbookID)) var response *Orderbook - return response, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, common.EncodeURLValues(gateioOptionsOrderbook, params), &response) + return response, g.SendHTTPRequest(ctx, exchange.RestSpot, publicOrderbookOptionsEPL, common.EncodeURLValues(gateioOptionsOrderbook, params), &response) } // GetOptionAccounts lists option accounts func (g *Gateio) GetOptionAccounts(ctx context.Context) (*OptionAccount, error) { var resp *OptionAccount - return resp, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, gateioOptionAccounts, nil, nil, &resp) + return resp, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, optionsAccountsEPL, http.MethodGet, gateioOptionAccounts, nil, nil, &resp) } // GetAccountChangingHistory retrieves list of account changing history @@ -3406,7 +3288,7 @@ func (g *Gateio) GetAccountChangingHistory(ctx context.Context, offset, limit ui params.Set("to", strconv.FormatInt(to.Unix(), 10)) } var accountBook []AccountBook - return accountBook, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, gateioOptionsAccountbook, params, nil, &accountBook) + return accountBook, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, optionsAccountBooksEPL, http.MethodGet, gateioOptionsAccountbook, params, nil, &accountBook) } // GetUsersPositionSpecifiedUnderlying lists user's positions of specified underlying @@ -3416,7 +3298,7 @@ func (g *Gateio) GetUsersPositionSpecifiedUnderlying(ctx context.Context, underl params.Set("underlying", underlying) } var response []UsersPositionForUnderlying - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, gateioOptionsPosition, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, optionsPositions, http.MethodGet, gateioOptionsPosition, params, nil, &response) } // GetSpecifiedContractPosition retrieves specified contract position @@ -3425,8 +3307,7 @@ func (g *Gateio) GetSpecifiedContractPosition(ctx context.Context, contract curr return nil, errInvalidOrMissingContractParam } var response *UsersPositionForUnderlying - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, - gateioOptionsPosition+"/"+contract.String(), nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, optionsPositions, http.MethodGet, gateioOptionsPosition+"/"+contract.String(), nil, nil, &response) } // GetUsersLiquidationHistoryForSpecifiedUnderlying retrieves user's liquidation history of specified underlying @@ -3440,7 +3321,7 @@ func (g *Gateio) GetUsersLiquidationHistoryForSpecifiedUnderlying(ctx context.Co params.Set("contract", contract.String()) } var response []ContractClosePosition - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, gateioOptionsPositionClose, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, optionsLiquidationHistoryEPL, http.MethodGet, gateioOptionsPositionClose, params, nil, &response) } // PlaceOptionOrder creates an options order @@ -3461,8 +3342,7 @@ func (g *Gateio) PlaceOptionOrder(ctx context.Context, arg *OptionOrderParam) (* arg.Price = 0 } var response *OptionOrderResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPlaceOrdersEPL, http.MethodPost, - gateioOptionsOrders, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, optionsSubmitOrderEPL, http.MethodPost, gateioOptionsOrders, nil, &arg, &response) } // GetOptionFuturesOrders retrieves futures orders @@ -3491,8 +3371,7 @@ func (g *Gateio) GetOptionFuturesOrders(ctx context.Context, contract currency.P params.Set("to", strconv.FormatInt(to.Unix(), 10)) } var response []OptionOrderResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, - http.MethodGet, gateioOptionsOrders, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, optionsOrdersEPL, http.MethodGet, gateioOptionsOrders, params, nil, &response) } // CancelMultipleOptionOpenOrders cancels all open orders matched @@ -3508,8 +3387,7 @@ func (g *Gateio) CancelMultipleOptionOpenOrders(ctx context.Context, contract cu params.Set("side", side) } var response []OptionOrderResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, - exchange.RestSpot, perpetualSwapCancelOrdersEPL, http.MethodDelete, gateioOptionsOrders, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, optionsCancelOrdersEPL, http.MethodDelete, gateioOptionsOrders, params, nil, &response) } // GetSingleOptionOrder retrieves a single option order @@ -3518,7 +3396,7 @@ func (g *Gateio) GetSingleOptionOrder(ctx context.Context, orderID string) (*Opt return nil, errInvalidOrderID } var o *OptionOrderResponse - return o, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, gateioOptionsOrders+"/"+orderID, nil, nil, &o) + return o, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, optionsOrderEPL, http.MethodGet, gateioOptionsOrders+"/"+orderID, nil, nil, &o) } // CancelOptionSingleOrder cancel a single order. @@ -3527,8 +3405,7 @@ func (g *Gateio) CancelOptionSingleOrder(ctx context.Context, orderID string) (* return nil, errInvalidOrderID } var response *OptionOrderResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapCancelOrdersEPL, http.MethodDelete, - "options/orders/"+orderID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, optionsCancelOrderEPL, http.MethodDelete, "options/orders/"+orderID, nil, nil, &response) } // GetOptionsPersonalTradingHistory retrieves personal tradign histories given the underlying{Required}, contract, and other pagination params. @@ -3554,7 +3431,7 @@ func (g *Gateio) GetOptionsPersonalTradingHistory(ctx context.Context, underlyin params.Set("to", strconv.FormatInt(to.Unix(), 10)) } var resp []OptionTradingHistory - return resp, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, perpetualSwapPrivateEPL, http.MethodGet, gateioOptionsMyTrades, params, nil, &resp) + return resp, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, optionsTradingHistoryEPL, http.MethodGet, gateioOptionsMyTrades, params, nil, &resp) } // GetOptionsTickers lists tickers of options contracts @@ -3564,8 +3441,7 @@ func (g *Gateio) GetOptionsTickers(ctx context.Context, underlying string) ([]Op } underlying = strings.ToUpper(underlying) var response []OptionsTicker - return response, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, - gateioOptionsTickers+"?underlying="+underlying, &response) + return response, g.SendHTTPRequest(ctx, exchange.RestSpot, publicTickerOptionsEPL, gateioOptionsTickers+"?underlying="+underlying, &response) } // GetOptionUnderlyingTickers retrieves options underlying ticker @@ -3574,8 +3450,7 @@ func (g *Gateio) GetOptionUnderlyingTickers(ctx context.Context, underlying stri return nil, errInvalidUnderlying } var respos *OptionsUnderlyingTicker - return respos, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, - "options/underlying/tickers/"+underlying, &respos) + return respos, g.SendHTTPRequest(ctx, exchange.RestSpot, publicUnderlyingTickerOptionsEPL, "options/underlying/tickers/"+underlying, &respos) } // GetOptionFuturesCandlesticks retrieves option futures candlesticks @@ -3600,8 +3475,7 @@ func (g *Gateio) GetOptionFuturesCandlesticks(ctx context.Context, contract curr } params.Set("interval", intervalString) var candles []FuturesCandlestick - return candles, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, - common.EncodeURLValues(gateioOptionCandlesticks, params), &candles) + return candles, g.SendHTTPRequest(ctx, exchange.RestSpot, publicCandleSticksOptionsEPL, common.EncodeURLValues(gateioOptionCandlesticks, params), &candles) } // GetOptionFuturesMarkPriceCandlesticks retrieves mark price candlesticks of an underlying @@ -3628,12 +3502,11 @@ func (g *Gateio) GetOptionFuturesMarkPriceCandlesticks(ctx context.Context, unde params.Set("interval", intervalString) } var candles []FuturesCandlestick - return candles, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, - common.EncodeURLValues(gateioOptionUnderlyingCandlesticks, params), &candles) + return candles, g.SendHTTPRequest(ctx, exchange.RestSpot, publicMarkpriceCandleSticksOptionsEPL, common.EncodeURLValues(gateioOptionUnderlyingCandlesticks, params), &candles) } // GetOptionsTradeHistory retrieves options trade history -func (g *Gateio) GetOptionsTradeHistory(ctx context.Context, contract /*C is call, while P is put*/ currency.Pair, callType string, offset, limit uint64, from, to time.Time) ([]TradingHistoryItem, error) { +func (g *Gateio) GetOptionsTradeHistory(ctx context.Context, contract currency.Pair, callType string, offset, limit uint64, from, to time.Time) ([]TradingHistoryItem, error) { params := url.Values{} callType = strings.ToUpper(callType) if callType == "C" || callType == "P" { @@ -3655,7 +3528,7 @@ func (g *Gateio) GetOptionsTradeHistory(ctx context.Context, contract /*C is cal params.Set("to", strconv.FormatInt(to.Unix(), 10)) } var trades []TradingHistoryItem - return trades, g.SendHTTPRequest(ctx, exchange.RestSpot, perpetualSwapDefaultEPL, common.EncodeURLValues(gateioOptionsTrades, params), &trades) + return trades, g.SendHTTPRequest(ctx, exchange.RestSpot, publicTradeHistoryOptionsEPL, common.EncodeURLValues(gateioOptionsTrades, params), &trades) } // ********************************** Flash_SWAP ************************* @@ -3663,7 +3536,7 @@ func (g *Gateio) GetOptionsTradeHistory(ctx context.Context, contract /*C is cal // GetSupportedFlashSwapCurrencies retrieves all supported currencies in flash swap func (g *Gateio) GetSupportedFlashSwapCurrencies(ctx context.Context) ([]SwapCurrencies, error) { var currencies []SwapCurrencies - return currencies, g.SendHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, gateioFlashSwapCurrencies, ¤cies) + return currencies, g.SendHTTPRequest(ctx, exchange.RestSpot, publicFlashSwapEPL, gateioFlashSwapCurrencies, ¤cies) } // CreateFlashSwapOrder creates a new flash swap order @@ -3685,7 +3558,7 @@ func (g *Gateio) CreateFlashSwapOrder(ctx context.Context, arg FlashSwapOrderPar return nil, fmt.Errorf("%w, buy_amount amount can not be less than or equal to 0", errInvalidAmount) } var response *FlashSwapOrderResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotDefaultEPL, http.MethodPost, gateioFlashSwapOrders, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, flashSwapOrderEPL, http.MethodPost, gateioFlashSwapOrders, nil, &arg, &response) } // GetAllFlashSwapOrders retrieves list of flash swap orders filtered by the params @@ -3708,7 +3581,7 @@ func (g *Gateio) GetAllFlashSwapOrders(ctx context.Context, status int, sellCurr params.Set("limit", strconv.FormatUint(limit, 10)) } var response []FlashSwapOrderResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, gateioFlashSwapOrders, params, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, flashGetOrdersEPL, http.MethodGet, gateioFlashSwapOrders, params, nil, &response) } // GetSingleFlashSwapOrder get a single flash swap order's detail @@ -3717,8 +3590,7 @@ func (g *Gateio) GetSingleFlashSwapOrder(ctx context.Context, orderID string) (* return nil, fmt.Errorf("%w, flash order order_id must not be empty", errInvalidOrderID) } var response *FlashSwapOrderResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodGet, - gateioFlashSwapOrders+"/"+orderID, nil, nil, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, flashGetOrderEPL, http.MethodGet, gateioFlashSwapOrders+"/"+orderID, nil, nil, &response) } // InitiateFlashSwapOrderReview initiate a flash swap order preview @@ -3733,7 +3605,7 @@ func (g *Gateio) InitiateFlashSwapOrderReview(ctx context.Context, arg FlashSwap return nil, fmt.Errorf("%w, sell currency can not empty", currency.ErrCurrencyCodeEmpty) } var response *InitFlashSwapOrderPreviewResponse - return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, spotPrivateEPL, http.MethodPost, gateioFlashSwapOrdersPreview, nil, &arg, &response) + return response, g.SendAuthenticatedHTTPRequest(ctx, exchange.RestSpot, flashOrderReviewEPL, http.MethodPost, gateioFlashSwapOrdersPreview, nil, &arg, &response) } // IsValidPairString returns true if the string represents a valid currency pair diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index 102412dd162..76343bb7203 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -332,7 +332,7 @@ func TestCreateBatchOrders(t *testing.T) { func TestGetSpotOpenOrders(t *testing.T) { t.Parallel() sharedtestvalues.SkipTestIfCredentialsUnset(t, g) - if _, err := g.GateioSpotOpenOrders(context.Background(), 0, 0, false); err != nil { + if _, err := g.GetSpotOpenOrders(context.Background(), 0, 0, false); err != nil { t.Errorf("%s GetSpotOpenOrders() error %v", g.Name, err) } } diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 02d2b760fa6..b97109e8988 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -151,7 +151,7 @@ func (g *Gateio) SetDefaults() { } g.Requester, err = request.New(g.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), - request.WithLimiter(GetRateLimit()), + request.WithLimiter(packageRateLimits), ) if err != nil { log.Errorln(log.ExchangeSys, err) @@ -1658,7 +1658,7 @@ func (g *Gateio) GetActiveOrders(ctx context.Context, req *order.MultiOrderReque switch req.AssetType { case asset.Spot, asset.Margin, asset.CrossMargin: var spotOrders []SpotOrdersDetail - spotOrders, err = g.GateioSpotOpenOrders(ctx, 0, 0, req.AssetType == asset.CrossMargin) + spotOrders, err = g.GetSpotOpenOrders(ctx, 0, 0, req.AssetType == asset.CrossMargin) if err != nil { return nil, err } diff --git a/exchanges/gateio/ratelimiter.go b/exchanges/gateio/ratelimiter.go index 2c554b38e5d..c42fa63d074 100644 --- a/exchanges/gateio/ratelimiter.go +++ b/exchanges/gateio/ratelimiter.go @@ -6,49 +6,404 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/request" ) -// GateIO endpoints limits. +// GateIO endpoints limits. See: https://www.gate.io/docs/developers/apiv4/en/#frequency-limit-rule const ( - spotDefaultEPL request.EndpointLimit = iota - spotPrivateEPL - spotPlaceOrdersEPL - spotCancelOrdersEPL - perpetualSwapDefaultEPL - perpetualSwapPlaceOrdersEPL - perpetualSwapPrivateEPL - perpetualSwapCancelOrdersEPL - walletEPL - withdrawalEPL - - // Request rates per interval - - spotPublicRate = 900 - spotPrivateRate = 900 - spotPlaceOrdersRate = 10 - spotCancelOrdersRate = 500 - perpetualSwapPublicRate = 300 - perpetualSwapPlaceOrdersRate = 100 - perpetualSwapPrivateRate = 400 - perpetualSwapCancelOrdersRate = 400 - walletRate = 200 - withdrawalRate = 1 - - // interval - oneSecondInterval = time.Second - threeSecondsInterval = time.Second * 3 + publicTickersSpotEPL request.EndpointLimit = iota + 1 + publicOrderbookSpotEPL + publicMarketTradesSpotEPL + publicCandleStickSpotEPL + publicCurrencyPairDetailSpotEPL + publicListCurrencyPairsSpotEPL + publicCurrenciesSpotEPL + + publicCurrencyPairsMarginEPL + publicOrderbookMarginEPL + + publicInsuranceDeliveryEPL + publicDeliveryContractsEPL + publicOrderbookDeliveryEPL + publicTradingHistoryDeliveryEPL + publicCandleSticksDeliveryEPL + publicTickersDeliveryEPL + + publicFuturesContractsEPL + publicOrderbookFuturesEPL + publicTradingHistoryFuturesEPL + publicCandleSticksFuturesEPL + publicPremiumIndexEPL + publicTickersFuturesEPL + publicFundingRatesEPL + publicInsuranceFuturesEPL + publicStatsFuturesEPL + publicIndexConstituentsEPL + publicLiquidationHistoryEPL + + publicUnderlyingOptionsEPL + publicExpirationOptionsEPL + publicContractsOptionsEPL + publicSettlementOptionsEPL + publicOrderbookOptionsEPL + publicTickerOptionsEPL + publicUnderlyingTickerOptionsEPL + publicCandleSticksOptionsEPL + publicMarkpriceCandleSticksOptionsEPL + publicTradeHistoryOptionsEPL + + publicGetServerTimeEPL + publicFlashSwapEPL + publicListCurrencyChainEPL + + walletDepositAddressEPL + walletWithdrawalRecordsEPL + walletDepositRecordsEPL + walletTransferCurrencyEPL + walletSubAccountTransferEPL + walletSubAccountTransferHistoryEPL + walletSubAccountToSubAccountTransferEPL + walletWithdrawStatusEPL + walletSubAccountBalancesEPL + walletSubAccountMarginBalancesEPL + walletSubAccountFuturesBalancesEPL + walletSubAccountCrossMarginBalancesEPL + walletSavedAddressesEPL + walletTradingFeeEPL + walletTotalBalanceEPL + walletWithdrawEPL + walletCancelWithdrawEPL + + subAccountEPL + + spotTradingFeeEPL + spotAccountsEPL + spotGetOpenOrdersEPL + spotClosePositionEPL + spotBatchOrdersEPL + spotPlaceOrderEPL + spotGetOrdersEPL + spotCancelAllOpenOrdersEPL + spotCancelBatchOrdersEPL + spotGetOrderEPL + spotAmendOrderEPL + spotCancelSingleOrderEPL + spotTradingHistoryEPL + spotCountdownCancelEPL + spotCreateTriggerOrderEPL + spotGetTriggerOrderListEPL + spotCancelTriggerOrdersEPL + spotGetTriggerOrderEPL + spotCancelTriggerOrderEPL + + marginAccountListEPL + marginAccountBalanceEPL + marginFundingAccountListEPL + marginLendBorrowEPL + marginAllLoansEPL + marginMergeLendingLoansEPL + marginGetLoanEPL + marginModifyLoanEPL + marginCancelLoanEPL + marginRepayLoanEPL + marginListLoansEPL + marginRepaymentRecordEPL + marginSingleRecordEPL + marginModifyLoanRecordEPL + marginAutoRepayEPL + marginGetAutoRepaySettingsEPL + marginGetMaxTransferEPL + marginGetMaxBorrowEPL + marginSupportedCurrencyCrossListEPL + marginSupportedCurrencyCrossEPL + marginAccountsEPL + marginAccountHistoryEPL + marginCreateCrossBorrowLoanEPL + marginExecuteRepaymentsEPL + marginGetCrossMarginRepaymentsEPL + marginGetMaxTransferCrossEPL + marginGetMaxBorrowCrossEPL + marginGetCrossBorrowHistoryEPL + marginGetBorrowEPL + + flashSwapOrderEPL + flashGetOrdersEPL + flashGetOrderEPL + flashOrderReviewEPL + + privateUnifiedSpotEPL + + perpetualAccountEPL + perpetualAccountBooksEPL + perpetualPositionsEPL + perpetualPositionEPL + perpetualUpdateMarginEPL + perpetualUpdateLeverageEPL + perpetualUpdateRiskEPL + perpetualToggleDualModeEPL + perpetualPositionsDualModeEPL + perpetualUpdateMarginDualModeEPL + perpetualUpdateLeverageDualModeEPL + perpetualUpdateRiskDualModeEPL + perpetualSubmitOrderEPL + perpetualGetOrdersEPL + perpetualSubmitBatchOrdersEPL + perpetualFetchOrderEPL + perpetualCancelOrderEPL + perpetualAmendOrderEPL + perpetualTradingHistoryEPL + perpetualClosePositionEPL + perpetualLiquidationHistoryEPL + perpetualCancelTriggerOrdersEPL + perpetualSubmitTriggerOrderEPL + perpetualListOpenOrdersEPL + perpetualCancelOpenOrdersEPL + perpetualGetTriggerOrderEPL + perpetualCancelTriggerOrderEPL + + deliveryAccountEPL + deliveryAccountBooksEPL + deliveryPositionsEPL + deliveryUpdateMarginEPL + deliveryUpdateLeverageEPL + deliveryUpdateRiskLimitEPL + deliverySubmitOrderEPL + deliveryGetOrdersEPL + deliveryCancelOrdersEPL + deliveryGetOrderEPL + deliveryCancelOrderEPL + deliveryTradingHistoryEPL + deliveryCloseHistoryEPL + deliveryLiquidationHistoryEPL + deliverySettlementHistoryEPL + deliveryGetTriggerOrdersEPL + deliveryAutoOrdersEPL + deliveryCancelTriggerOrdersEPL + deliveryGetTriggerOrderEPL + deliveryCancelTriggerOrderEPL + + optionsSettlementsEPL + optionsAccountsEPL + optionsAccountBooksEPL + optionsPositions + optionsLiquidationHistoryEPL + optionsSubmitOrderEPL + optionsOrdersEPL + optionsCancelOrdersEPL + optionsOrderEPL + optionsCancelOrderEPL + optionsTradingHistoryEPL ) -// GetRateLimit returns the rate limiter for the exchange -func GetRateLimit() request.RateLimitDefinitions { - return request.RateLimitDefinitions{ - spotDefaultEPL: request.NewRateLimitWithWeight(oneSecondInterval, spotPublicRate, 1), - spotPrivateEPL: request.NewRateLimitWithWeight(oneSecondInterval, spotPrivateRate, 1), - spotPlaceOrdersEPL: request.NewRateLimitWithWeight(oneSecondInterval, spotPlaceOrdersRate, 1), - spotCancelOrdersEPL: request.NewRateLimitWithWeight(oneSecondInterval, spotCancelOrdersRate, 1), - perpetualSwapDefaultEPL: request.NewRateLimitWithWeight(oneSecondInterval, perpetualSwapPublicRate, 1), - perpetualSwapPlaceOrdersEPL: request.NewRateLimitWithWeight(oneSecondInterval, perpetualSwapPlaceOrdersRate, 1), - perpetualSwapPrivateEPL: request.NewRateLimitWithWeight(oneSecondInterval, perpetualSwapPrivateRate, 1), - perpetualSwapCancelOrdersEPL: request.NewRateLimitWithWeight(oneSecondInterval, perpetualSwapCancelOrdersRate, 1), - walletEPL: request.NewRateLimitWithWeight(oneSecondInterval, walletRate, 1), - withdrawalEPL: request.NewRateLimitWithWeight(threeSecondsInterval, withdrawalRate, 1), - } +// package level rate limits for REST API +var packageRateLimits = request.RateLimitDefinitions{ + publicOrderbookSpotEPL: standardRateLimit(), + publicMarketTradesSpotEPL: standardRateLimit(), + publicCandleStickSpotEPL: standardRateLimit(), + publicTickersSpotEPL: standardRateLimit(), + publicCurrencyPairDetailSpotEPL: standardRateLimit(), + publicListCurrencyPairsSpotEPL: standardRateLimit(), + publicCurrenciesSpotEPL: standardRateLimit(), + + publicCurrencyPairsMarginEPL: standardRateLimit(), + publicOrderbookMarginEPL: standardRateLimit(), + + publicInsuranceDeliveryEPL: standardRateLimit(), + publicDeliveryContractsEPL: standardRateLimit(), + publicOrderbookDeliveryEPL: standardRateLimit(), + publicCandleSticksDeliveryEPL: standardRateLimit(), + publicTickersDeliveryEPL: standardRateLimit(), + + publicFuturesContractsEPL: standardRateLimit(), + publicOrderbookFuturesEPL: standardRateLimit(), + publicTradingHistoryFuturesEPL: standardRateLimit(), + publicCandleSticksFuturesEPL: standardRateLimit(), + publicPremiumIndexEPL: standardRateLimit(), + publicTickersFuturesEPL: standardRateLimit(), + publicFundingRatesEPL: standardRateLimit(), + publicInsuranceFuturesEPL: standardRateLimit(), + publicStatsFuturesEPL: standardRateLimit(), + publicIndexConstituentsEPL: standardRateLimit(), + publicLiquidationHistoryEPL: standardRateLimit(), + + publicUnderlyingOptionsEPL: standardRateLimit(), + publicExpirationOptionsEPL: standardRateLimit(), + publicContractsOptionsEPL: standardRateLimit(), + publicSettlementOptionsEPL: standardRateLimit(), + publicOrderbookOptionsEPL: standardRateLimit(), + publicTickerOptionsEPL: standardRateLimit(), + publicUnderlyingTickerOptionsEPL: standardRateLimit(), + publicCandleSticksOptionsEPL: standardRateLimit(), + publicMarkpriceCandleSticksOptionsEPL: standardRateLimit(), + publicTradeHistoryOptionsEPL: standardRateLimit(), + + publicGetServerTimeEPL: standardRateLimit(), + publicFlashSwapEPL: standardRateLimit(), + publicListCurrencyChainEPL: standardRateLimit(), + + walletDepositAddressEPL: standardRateLimit(), + walletWithdrawalRecordsEPL: standardRateLimit(), + walletDepositRecordsEPL: standardRateLimit(), + walletTransferCurrencyEPL: personalAccountRateLimit(), + walletSubAccountTransferEPL: personalAccountRateLimit(), + walletSubAccountTransferHistoryEPL: standardRateLimit(), + walletSubAccountToSubAccountTransferEPL: personalAccountRateLimit(), + walletWithdrawStatusEPL: standardRateLimit(), + walletSubAccountBalancesEPL: personalAccountRateLimit(), + walletSubAccountMarginBalancesEPL: personalAccountRateLimit(), + walletSubAccountFuturesBalancesEPL: personalAccountRateLimit(), + walletSubAccountCrossMarginBalancesEPL: personalAccountRateLimit(), + walletSavedAddressesEPL: standardRateLimit(), + walletTradingFeeEPL: standardRateLimit(), + walletTotalBalanceEPL: personalAccountRateLimit(), + walletWithdrawEPL: request.NewRateLimitWithWeight(time.Second*3, 1, 1), // 1r/3s + walletCancelWithdrawEPL: standardRateLimit(), + + subAccountEPL: personalAccountRateLimit(), + + spotTradingFeeEPL: standardRateLimit(), + spotAccountsEPL: standardRateLimit(), + spotGetOpenOrdersEPL: standardRateLimit(), + spotClosePositionEPL: orderCloseRateLimit(), + spotBatchOrdersEPL: spotOrderPlacementRateLimit(), + spotPlaceOrderEPL: spotOrderPlacementRateLimit(), + spotGetOrdersEPL: standardRateLimit(), + spotCancelAllOpenOrdersEPL: orderCloseRateLimit(), + spotCancelBatchOrdersEPL: orderCloseRateLimit(), + spotGetOrderEPL: standardRateLimit(), + spotAmendOrderEPL: spotOrderPlacementRateLimit(), + spotCancelSingleOrderEPL: orderCloseRateLimit(), + spotTradingHistoryEPL: standardRateLimit(), + spotCountdownCancelEPL: orderCloseRateLimit(), + spotCreateTriggerOrderEPL: spotOrderPlacementRateLimit(), + spotGetTriggerOrderListEPL: standardRateLimit(), + spotCancelTriggerOrdersEPL: orderCloseRateLimit(), + spotGetTriggerOrderEPL: standardRateLimit(), + spotCancelTriggerOrderEPL: orderCloseRateLimit(), + + marginAccountListEPL: otherPrivateEndpointRateLimit(), + marginAccountBalanceEPL: otherPrivateEndpointRateLimit(), + marginFundingAccountListEPL: otherPrivateEndpointRateLimit(), + marginLendBorrowEPL: otherPrivateEndpointRateLimit(), + marginAllLoansEPL: otherPrivateEndpointRateLimit(), + marginMergeLendingLoansEPL: otherPrivateEndpointRateLimit(), + marginGetLoanEPL: otherPrivateEndpointRateLimit(), + marginModifyLoanEPL: otherPrivateEndpointRateLimit(), + marginCancelLoanEPL: otherPrivateEndpointRateLimit(), + marginRepayLoanEPL: otherPrivateEndpointRateLimit(), + marginListLoansEPL: otherPrivateEndpointRateLimit(), + marginRepaymentRecordEPL: otherPrivateEndpointRateLimit(), + marginSingleRecordEPL: otherPrivateEndpointRateLimit(), + marginModifyLoanRecordEPL: otherPrivateEndpointRateLimit(), + marginAutoRepayEPL: otherPrivateEndpointRateLimit(), + marginGetAutoRepaySettingsEPL: otherPrivateEndpointRateLimit(), + marginGetMaxTransferEPL: otherPrivateEndpointRateLimit(), + marginGetMaxBorrowEPL: otherPrivateEndpointRateLimit(), + marginSupportedCurrencyCrossListEPL: otherPrivateEndpointRateLimit(), + marginSupportedCurrencyCrossEPL: otherPrivateEndpointRateLimit(), + marginAccountsEPL: otherPrivateEndpointRateLimit(), + marginAccountHistoryEPL: otherPrivateEndpointRateLimit(), + marginCreateCrossBorrowLoanEPL: otherPrivateEndpointRateLimit(), + marginExecuteRepaymentsEPL: otherPrivateEndpointRateLimit(), + marginGetCrossMarginRepaymentsEPL: otherPrivateEndpointRateLimit(), + marginGetMaxTransferCrossEPL: otherPrivateEndpointRateLimit(), + marginGetMaxBorrowCrossEPL: otherPrivateEndpointRateLimit(), + marginGetCrossBorrowHistoryEPL: otherPrivateEndpointRateLimit(), + marginGetBorrowEPL: otherPrivateEndpointRateLimit(), + + flashSwapOrderEPL: otherPrivateEndpointRateLimit(), + flashGetOrdersEPL: otherPrivateEndpointRateLimit(), + flashGetOrderEPL: otherPrivateEndpointRateLimit(), + flashOrderReviewEPL: otherPrivateEndpointRateLimit(), + + perpetualAccountEPL: standardRateLimit(), + perpetualAccountBooksEPL: standardRateLimit(), + perpetualPositionsEPL: standardRateLimit(), + perpetualPositionEPL: standardRateLimit(), + perpetualUpdateMarginEPL: standardRateLimit(), + perpetualUpdateLeverageEPL: standardRateLimit(), + perpetualUpdateRiskEPL: standardRateLimit(), + perpetualToggleDualModeEPL: standardRateLimit(), + perpetualPositionsDualModeEPL: standardRateLimit(), + perpetualUpdateMarginDualModeEPL: standardRateLimit(), + perpetualUpdateLeverageDualModeEPL: standardRateLimit(), + perpetualUpdateRiskDualModeEPL: standardRateLimit(), + perpetualSubmitOrderEPL: perpetualOrderplacementRateLimit(), + perpetualGetOrdersEPL: standardRateLimit(), + perpetualSubmitBatchOrdersEPL: perpetualOrderplacementRateLimit(), + perpetualFetchOrderEPL: standardRateLimit(), + perpetualCancelOrderEPL: orderCloseRateLimit(), + perpetualAmendOrderEPL: perpetualOrderplacementRateLimit(), + perpetualTradingHistoryEPL: standardRateLimit(), + perpetualClosePositionEPL: orderCloseRateLimit(), + perpetualLiquidationHistoryEPL: standardRateLimit(), + perpetualCancelTriggerOrdersEPL: orderCloseRateLimit(), + perpetualSubmitTriggerOrderEPL: perpetualOrderplacementRateLimit(), + perpetualListOpenOrdersEPL: standardRateLimit(), + perpetualCancelOpenOrdersEPL: orderCloseRateLimit(), + perpetualGetTriggerOrderEPL: standardRateLimit(), + perpetualCancelTriggerOrderEPL: orderCloseRateLimit(), + + deliveryAccountEPL: standardRateLimit(), + deliveryAccountBooksEPL: standardRateLimit(), + deliveryPositionsEPL: standardRateLimit(), + deliveryUpdateMarginEPL: standardRateLimit(), + deliveryUpdateLeverageEPL: standardRateLimit(), + deliveryUpdateRiskLimitEPL: standardRateLimit(), + deliverySubmitOrderEPL: deliverySubmitCancelAmendRateLimit(), + deliveryGetOrdersEPL: standardRateLimit(), + deliveryCancelOrdersEPL: deliverySubmitCancelAmendRateLimit(), + deliveryGetOrderEPL: standardRateLimit(), + deliveryCancelOrderEPL: deliverySubmitCancelAmendRateLimit(), + deliveryTradingHistoryEPL: standardRateLimit(), + deliveryCloseHistoryEPL: standardRateLimit(), + deliveryLiquidationHistoryEPL: standardRateLimit(), + deliverySettlementHistoryEPL: standardRateLimit(), + deliveryGetTriggerOrdersEPL: standardRateLimit(), + deliveryAutoOrdersEPL: standardRateLimit(), + deliveryCancelTriggerOrdersEPL: deliverySubmitCancelAmendRateLimit(), + deliveryGetTriggerOrderEPL: standardRateLimit(), + deliveryCancelTriggerOrderEPL: deliverySubmitCancelAmendRateLimit(), + + optionsSettlementsEPL: standardRateLimit(), + optionsAccountsEPL: standardRateLimit(), + optionsAccountBooksEPL: standardRateLimit(), + optionsPositions: standardRateLimit(), + optionsLiquidationHistoryEPL: standardRateLimit(), + optionsSubmitOrderEPL: optionsSubmitCancelAmendRateLimit(), + optionsOrdersEPL: standardRateLimit(), + optionsCancelOrdersEPL: optionsSubmitCancelAmendRateLimit(), + optionsCancelOrderEPL: optionsSubmitCancelAmendRateLimit(), + optionsTradingHistoryEPL: standardRateLimit(), + + privateUnifiedSpotEPL: standardRateLimit(), +} + +func standardRateLimit() *request.RateLimiterWithWeight { + return request.NewRateLimitWithWeight(time.Second*10, 200, 1) +} + +func personalAccountRateLimit() *request.RateLimiterWithWeight { + return request.NewRateLimitWithWeight(time.Second*10, 80, 1) +} + +func orderCloseRateLimit() *request.RateLimiterWithWeight { + return request.NewRateLimitWithWeight(time.Second, 200, 1) +} + +func spotOrderPlacementRateLimit() *request.RateLimiterWithWeight { + return request.NewRateLimitWithWeight(time.Second, 10, 1) +} + +func otherPrivateEndpointRateLimit() *request.RateLimiterWithWeight { + return request.NewRateLimitWithWeight(time.Second*10, 150, 1) +} + +func perpetualOrderplacementRateLimit() *request.RateLimiterWithWeight { + return request.NewRateLimitWithWeight(time.Second, 100, 1) +} + +func deliverySubmitCancelAmendRateLimit() *request.RateLimiterWithWeight { + return request.NewRateLimitWithWeight(time.Second*10, 500, 1) +} + +func optionsSubmitCancelAmendRateLimit() *request.RateLimiterWithWeight { + return request.NewRateLimitWithWeight(time.Second, 200, 1) } From 6803a2a988786b1e6c5b47050a4d4fcd0c156890 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 28 Nov 2024 10:37:09 +1100 Subject: [PATCH 123/138] Add test and missing --- exchanges/gateio/ratelimiter.go | 12 +++++++----- exchanges/gateio/ratelimiter_test.go | 16 ++++++++++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) create mode 100644 exchanges/gateio/ratelimiter_test.go diff --git a/exchanges/gateio/ratelimiter.go b/exchanges/gateio/ratelimiter.go index c42fa63d074..5826a4dbc14 100644 --- a/exchanges/gateio/ratelimiter.go +++ b/exchanges/gateio/ratelimiter.go @@ -205,11 +205,12 @@ var packageRateLimits = request.RateLimitDefinitions{ publicCurrencyPairsMarginEPL: standardRateLimit(), publicOrderbookMarginEPL: standardRateLimit(), - publicInsuranceDeliveryEPL: standardRateLimit(), - publicDeliveryContractsEPL: standardRateLimit(), - publicOrderbookDeliveryEPL: standardRateLimit(), - publicCandleSticksDeliveryEPL: standardRateLimit(), - publicTickersDeliveryEPL: standardRateLimit(), + publicInsuranceDeliveryEPL: standardRateLimit(), + publicDeliveryContractsEPL: standardRateLimit(), + publicOrderbookDeliveryEPL: standardRateLimit(), + publicTradingHistoryDeliveryEPL: standardRateLimit(), + publicCandleSticksDeliveryEPL: standardRateLimit(), + publicTickersDeliveryEPL: standardRateLimit(), publicFuturesContractsEPL: standardRateLimit(), publicOrderbookFuturesEPL: standardRateLimit(), @@ -370,6 +371,7 @@ var packageRateLimits = request.RateLimitDefinitions{ optionsSubmitOrderEPL: optionsSubmitCancelAmendRateLimit(), optionsOrdersEPL: standardRateLimit(), optionsCancelOrdersEPL: optionsSubmitCancelAmendRateLimit(), + optionsOrderEPL: standardRateLimit(), optionsCancelOrderEPL: optionsSubmitCancelAmendRateLimit(), optionsTradingHistoryEPL: standardRateLimit(), diff --git a/exchanges/gateio/ratelimiter_test.go b/exchanges/gateio/ratelimiter_test.go new file mode 100644 index 00000000000..926c815284d --- /dev/null +++ b/exchanges/gateio/ratelimiter_test.go @@ -0,0 +1,16 @@ +package gateio + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRateLimits(t *testing.T) { + for epl := range optionsTradingHistoryEPL { + if epl == 0 { + continue + } + assert.NotEmptyf(t, packageRateLimits[epl], "Empty rate limit not found for const %v", epl) + } +} From 4baa7a2099b10568a1d05b497e135f68290b2b8f Mon Sep 17 00:00:00 2001 From: shazbert Date: Thu, 5 Dec 2024 16:09:16 +1100 Subject: [PATCH 124/138] Shared REST rate limit definitions with Websocket service, set lookup item to nil for systems that do not require rate limiting; add glorious nit --- exchanges/gateio/gateio_websocket.go | 10 ++++------ exchanges/gateio/gateio_wrapper.go | 6 +----- exchanges/gateio/ratelimiter.go | 10 +++++++++- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 1711dcc501f..932768ec2a6 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -24,7 +24,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" @@ -32,8 +31,7 @@ import ( ) const ( - gateioWebsocketEndpoint = "wss://api.gateio.ws/ws/v4/" - gateioWebsocketRateLimit = 120 * time.Millisecond + gateioWebsocketEndpoint = "wss://api.gateio.ws/ws/v4/" spotPingChannel = "spot.ping" spotPongChannel = "spot.pong" @@ -90,7 +88,7 @@ func (g *Gateio) WsConnectSpot(ctx context.Context, conn stream.Connection) erro if err != nil { return err } - conn.SetupPingHandler(request.Unset, stream.PingHandler{ + conn.SetupPingHandler(websocketRateLimitNotNeededEPL, stream.PingHandler{ Websocket: true, Delay: time.Second * 15, Message: pingMessage, @@ -587,7 +585,7 @@ func (g *Gateio) manageSubs(ctx context.Context, event string, conn stream.Conne if err != nil { return err } - result, err := conn.SendMessageReturnResponse(ctx, request.Unset, msg.ID, msg) + result, err := conn.SendMessageReturnResponse(ctx, websocketRateLimitNotNeededEPL, msg.ID, msg) if err != nil { return err } @@ -698,7 +696,7 @@ func (g *Gateio) handleSubscription(ctx context.Context, conn stream.Connection, } var errs error for k := range payloads { - result, err := conn.SendMessageReturnResponse(ctx, request.Unset, payloads[k].ID, payloads[k]) + result, err := conn.SendMessageReturnResponse(ctx, websocketRateLimitNotNeededEPL, payloads[k].ID, payloads[k]) if err != nil { errs = common.AppendError(errs, err) continue diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index b97109e8988..bb28b848ea7 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -203,6 +203,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { FillsFeed: g.Features.Enabled.FillsFeed, TradeFeed: g.Features.Enabled.TradeFeed, UseMultiConnectionManagement: true, + RateLimitDefinitions: packageRateLimits, }) if err != nil { return err @@ -210,7 +211,6 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Spot connection err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: gateioWebsocketEndpoint, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: g.WsHandleSpotData, @@ -226,7 +226,6 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Futures connection - USDT margined err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: futuresWebsocketUsdtURL, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: func(ctx context.Context, incoming []byte) error { @@ -245,7 +244,6 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Futures connection - BTC margined err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: futuresWebsocketBtcURL, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: func(ctx context.Context, incoming []byte) error { @@ -265,7 +263,6 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Futures connection - Delivery - USDT margined err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: deliveryRealUSDTTradingURL, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: func(ctx context.Context, incoming []byte) error { @@ -284,7 +281,6 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Futures connection - Options return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: optionsWebsocketURL, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: g.WsHandleOptionsData, diff --git a/exchanges/gateio/ratelimiter.go b/exchanges/gateio/ratelimiter.go index 5826a4dbc14..870b4d9bc1a 100644 --- a/exchanges/gateio/ratelimiter.go +++ b/exchanges/gateio/ratelimiter.go @@ -190,6 +190,8 @@ const ( optionsOrderEPL optionsCancelOrderEPL optionsTradingHistoryEPL + + websocketRateLimitNotNeededEPL ) // package level rate limits for REST API @@ -254,7 +256,7 @@ var packageRateLimits = request.RateLimitDefinitions{ walletSavedAddressesEPL: standardRateLimit(), walletTradingFeeEPL: standardRateLimit(), walletTotalBalanceEPL: personalAccountRateLimit(), - walletWithdrawEPL: request.NewRateLimitWithWeight(time.Second*3, 1, 1), // 1r/3s + walletWithdrawEPL: withdrawFromWalletRateLimit(), walletCancelWithdrawEPL: standardRateLimit(), subAccountEPL: personalAccountRateLimit(), @@ -376,6 +378,8 @@ var packageRateLimits = request.RateLimitDefinitions{ optionsTradingHistoryEPL: standardRateLimit(), privateUnifiedSpotEPL: standardRateLimit(), + + websocketRateLimitNotNeededEPL: nil, // no rate limit for certain websocket functions } func standardRateLimit() *request.RateLimiterWithWeight { @@ -409,3 +413,7 @@ func deliverySubmitCancelAmendRateLimit() *request.RateLimiterWithWeight { func optionsSubmitCancelAmendRateLimit() *request.RateLimiterWithWeight { return request.NewRateLimitWithWeight(time.Second, 200, 1) } + +func withdrawFromWalletRateLimit() *request.RateLimiterWithWeight { + return request.NewRateLimitWithWeight(time.Second*3, 1, 1) +} From 9c56cbb413e7bca41aa42340d896cc10cb686561 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Mon, 9 Dec 2024 11:44:48 +1100 Subject: [PATCH 125/138] integrate rate limits for websocket trading spot --- exchanges/gateio/gateio_websocket.go | 4 ++-- .../gateio_websocket_delivery_futures.go | 3 +-- exchanges/gateio/gateio_websocket_futures.go | 3 +-- exchanges/gateio/gateio_websocket_option.go | 3 +-- .../gateio/gateio_websocket_request_spot.go | 18 +++++++++--------- 5 files changed, 14 insertions(+), 17 deletions(-) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 59fe21297b0..ac807c61f86 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -134,7 +134,7 @@ func (g *Gateio) websocketLogin(ctx context.Context, conn stream.Connection, cha req := WebsocketRequest{Time: tn, Channel: channel, Event: "api", Payload: payload} - resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, req.Payload.RequestID, req) + resp, err := conn.SendMessageReturnResponse(ctx, websocketRateLimitNotNeededEPL, req.Payload.RequestID, req) if err != nil { return err } @@ -181,7 +181,7 @@ func (g *Gateio) WsHandleSpotData(_ context.Context, respRaw []byte) error { switch push.Channel { // TODO: Convert function params below to only use push.Result case spotTickerChannel: - return g.processTicker(push.Result, push.Time.Time()) + return g.processTicker(push.Result, push.TimeMs.Time()) case spotTradesChannel: return g.processTrades(push.Result) case spotCandlesticksChannel: diff --git a/exchanges/gateio/gateio_websocket_delivery_futures.go b/exchanges/gateio/gateio_websocket_delivery_futures.go index d4e1c8abb33..eeb1a706053 100644 --- a/exchanges/gateio/gateio_websocket_delivery_futures.go +++ b/exchanges/gateio/gateio_websocket_delivery_futures.go @@ -13,7 +13,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" ) @@ -55,7 +54,7 @@ func (g *Gateio) WsDeliveryFuturesConnect(ctx context.Context, conn stream.Conne if err != nil { return err } - conn.SetupPingHandler(request.Unset, stream.PingHandler{ + conn.SetupPingHandler(websocketRateLimitNotNeededEPL, stream.PingHandler{ Websocket: true, Delay: time.Second * 5, MessageType: websocket.PingMessage, diff --git a/exchanges/gateio/gateio_websocket_futures.go b/exchanges/gateio/gateio_websocket_futures.go index 520eb816d58..ba393277553 100644 --- a/exchanges/gateio/gateio_websocket_futures.go +++ b/exchanges/gateio/gateio_websocket_futures.go @@ -19,7 +19,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" @@ -76,7 +75,7 @@ func (g *Gateio) WsFuturesConnect(ctx context.Context, conn stream.Connection) e if err != nil { return err } - conn.SetupPingHandler(request.Unset, stream.PingHandler{ + conn.SetupPingHandler(websocketRateLimitNotNeededEPL, stream.PingHandler{ Websocket: true, MessageType: websocket.PingMessage, Delay: time.Second * 15, diff --git a/exchanges/gateio/gateio_websocket_option.go b/exchanges/gateio/gateio_websocket_option.go index 40bc63fe336..091a9ad1f0f 100644 --- a/exchanges/gateio/gateio_websocket_option.go +++ b/exchanges/gateio/gateio_websocket_option.go @@ -18,7 +18,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" @@ -85,7 +84,7 @@ func (g *Gateio) WsOptionsConnect(ctx context.Context, conn stream.Connection) e if err != nil { return err } - conn.SetupPingHandler(request.Unset, stream.PingHandler{ + conn.SetupPingHandler(websocketRateLimitNotNeededEPL, stream.PingHandler{ Websocket: true, Delay: time.Second * 5, MessageType: websocket.PingMessage, diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index 52d11d0197b..44cb73efc78 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -50,10 +50,10 @@ func (g *Gateio) WebsocketOrderPlaceSpot(ctx context.Context, orders []Websocket if len(orders) == 1 { var singleResponse WebsocketOrderResponse - return []WebsocketOrderResponse{singleResponse}, g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, orders[0], &singleResponse, 2) + return []WebsocketOrderResponse{singleResponse}, g.SendWebsocketRequest(ctx, spotPlaceOrderEPL, "spot.order_place", asset.Spot, orders[0], &singleResponse, 2) } var resp []WebsocketOrderResponse - return resp, g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, orders, &resp, 2) + return resp, g.SendWebsocketRequest(ctx, spotBatchOrdersEPL, "spot.order_place", asset.Spot, orders, &resp, 2) } // WebsocketOrderCancelSpot cancels an order via the websocket connection @@ -68,7 +68,7 @@ func (g *Gateio) WebsocketOrderCancelSpot(ctx context.Context, orderID string, p params := &WebsocketOrderRequest{OrderID: orderID, Pair: pair.String(), Account: account} var resp WebsocketOrderResponse - return &resp, g.SendWebsocketRequest(ctx, "spot.order_cancel", asset.Spot, params, &resp, 1) + return &resp, g.SendWebsocketRequest(ctx, spotCancelSingleOrderEPL, "spot.order_cancel", asset.Spot, params, &resp, 1) } // WebsocketOrderCancelAllByIDsSpot cancels multiple orders via the websocket @@ -87,7 +87,7 @@ func (g *Gateio) WebsocketOrderCancelAllByIDsSpot(ctx context.Context, o []Webso } var resp []WebsocketCancellAllResponse - return resp, g.SendWebsocketRequest(ctx, "spot.order_cancel_ids", asset.Spot, o, &resp, 2) + return resp, g.SendWebsocketRequest(ctx, spotCancelBatchOrdersEPL, "spot.order_cancel_ids", asset.Spot, o, &resp, 2) } // WebsocketOrderCancelAllByPairSpot cancels all orders for a specific pair @@ -109,7 +109,7 @@ func (g *Gateio) WebsocketOrderCancelAllByPairSpot(ctx context.Context, pair cur } var resp []WebsocketOrderResponse - return resp, g.SendWebsocketRequest(ctx, "spot.order_cancel_cp", asset.Spot, params, &resp, 1) + return resp, g.SendWebsocketRequest(ctx, spotCancelAllOpenOrdersEPL, "spot.order_cancel_cp", asset.Spot, params, &resp, 1) } // WebsocketOrderAmendSpot amends an order via the websocket connection @@ -131,7 +131,7 @@ func (g *Gateio) WebsocketOrderAmendSpot(ctx context.Context, amend *WebsocketAm } var resp WebsocketOrderResponse - return &resp, g.SendWebsocketRequest(ctx, "spot.order_amend", asset.Spot, amend, &resp, 1) + return &resp, g.SendWebsocketRequest(ctx, spotAmendOrderEPL, "spot.order_amend", asset.Spot, amend, &resp, 1) } // WebsocketGetOrderStatusSpot gets the status of an order via the websocket connection @@ -146,7 +146,7 @@ func (g *Gateio) WebsocketGetOrderStatusSpot(ctx context.Context, orderID string params := &WebsocketOrderRequest{OrderID: orderID, Pair: pair.String(), Account: account} var resp WebsocketOrderResponse - return &resp, g.SendWebsocketRequest(ctx, "spot.order_status", asset.Spot, params, &resp, 1) + return &resp, g.SendWebsocketRequest(ctx, spotGetOrdersEPL, "spot.order_status", asset.Spot, params, &resp, 1) } // funnelResult is used to unmarshal the result of a websocket request back to the required caller type @@ -155,7 +155,7 @@ type funnelResult struct { } // SendWebsocketRequest sends a websocket request to the exchange -func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connSignature, params, result any, expectedResponses int) error { +func (g *Gateio) SendWebsocketRequest(ctx context.Context, epl request.EndpointLimit, channel string, connSignature, params, result any, expectedResponses int) error { paramPayload, err := json.Marshal(params) if err != nil { return err @@ -180,7 +180,7 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connS }, } - responses, err := conn.SendMessageReturnResponsesWithInspector(ctx, request.Unset, req.Payload.RequestID, req, expectedResponses, wsRespAckInspector{}) + responses, err := conn.SendMessageReturnResponsesWithInspector(ctx, epl, req.Payload.RequestID, req, expectedResponses, wsRespAckInspector{}) if err != nil { return err } From 999a1089399609d3cb9fe0b5f2e1cf69b1f80609 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 12 Dec 2024 15:28:55 +1100 Subject: [PATCH 126/138] bitstamp: fix issue --- exchanges/bitstamp/bitstamp_websocket.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index de45f2fba95..f7154f2ba25 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -136,7 +136,7 @@ func (b *Bitstamp) handleWSSubscription(event string, respRaw []byte) error { } event = strings.TrimSuffix(event, "scription_succeeded") if !b.Websocket.Match.IncomingWithData(event+":"+channel, respRaw) { - return fmt.Errorf("%w: %s", stream.ErrNoMessageListener, event+":"+channel) + return fmt.Errorf("%w: %s", stream.ErrSignatureNotMatched, event+":"+channel) } return nil } From c0acea2a82575ad5d646945427dd537d363e08f5 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 12 Dec 2024 15:46:19 +1100 Subject: [PATCH 127/138] glorious: nits --- exchanges/bitfinex/bitfinex_websocket.go | 8 ++++---- exchanges/bitmex/bitmex_websocket.go | 2 +- exchanges/bitstamp/bitstamp_websocket.go | 5 +---- exchanges/gateio/gateio_websocket.go | 4 ++-- exchanges/kucoin/kucoin_websocket.go | 2 +- exchanges/stream/stream_match.go | 4 ++-- exchanges/stream/stream_match_test.go | 6 +++--- 7 files changed, 14 insertions(+), 17 deletions(-) diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index 0b213e92d33..ee1b854b0b0 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -456,18 +456,18 @@ func (b *Bitfinex) handleWSEvent(respRaw []byte) error { if err != nil { return fmt.Errorf("%w 'chanId': %w from message: %s", errParsingWSField, err, respRaw) } - err = b.Websocket.Match.EnsureMatchWithData("unsubscribe:"+chanID, respRaw) + err = b.Websocket.Match.RequireMatchWithData("unsubscribe:"+chanID, respRaw) if err != nil { return fmt.Errorf("%w: unsubscribe:%v", err, chanID) } case wsEventError: if subID, err := jsonparser.GetUnsafeString(respRaw, "subId"); err == nil { - err = b.Websocket.Match.EnsureMatchWithData("subscribe:"+subID, respRaw) + err = b.Websocket.Match.RequireMatchWithData("subscribe:"+subID, respRaw) if err != nil { return fmt.Errorf("%w: subscribe:%v", err, subID) } } else if chanID, err := jsonparser.GetUnsafeString(respRaw, "chanId"); err == nil { - err = b.Websocket.Match.EnsureMatchWithData("unsubscribe:"+chanID, respRaw) + err = b.Websocket.Match.RequireMatchWithData("unsubscribe:"+chanID, respRaw) if err != nil { return fmt.Errorf("%w: unsubscribe:%v", err, chanID) } @@ -543,7 +543,7 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s Pair: %s ChannelID: %d\n", b.Name, c.Channel, c.Pairs, chanID) } - return b.Websocket.Match.EnsureMatchWithData("subscribe:"+subID, respRaw) + return b.Websocket.Match.RequireMatchWithData("subscribe:"+subID, respRaw) } func (b *Bitfinex) handleWSChannelUpdate(s *subscription.Subscription, eventType string, d []interface{}) error { diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index eae34eea4cc..d322d5a6713 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -170,7 +170,7 @@ func (b *Bitmex) wsHandleData(respRaw []byte) error { if e2 != nil { return fmt.Errorf("%w parsing stream", e2) } - err = b.Websocket.Match.EnsureMatchWithData(op+":"+streamID, msg) + err = b.Websocket.Match.RequireMatchWithData(op+":"+streamID, msg) if err != nil { return fmt.Errorf("%w: %s:%s", err, op, streamID) } diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index f7154f2ba25..8fe4f2fc1fa 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -135,10 +135,7 @@ func (b *Bitstamp) handleWSSubscription(event string, respRaw []byte) error { return fmt.Errorf("%w `channel`: %w", errParsingWSField, err) } event = strings.TrimSuffix(event, "scription_succeeded") - if !b.Websocket.Match.IncomingWithData(event+":"+channel, respRaw) { - return fmt.Errorf("%w: %s", stream.ErrSignatureNotMatched, event+":"+channel) - } - return nil + return b.Websocket.Match.RequireMatchWithData(event+":"+channel, respRaw) } func (b *Bitstamp) handleWSTrade(msg []byte) error { diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index ac807c61f86..ea53f12029a 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -172,11 +172,11 @@ func (g *Gateio) WsHandleSpotData(_ context.Context, respRaw []byte) error { } if push.RequestID != "" { - return g.Websocket.Match.EnsureMatchWithData(push.RequestID, respRaw) + return g.Websocket.Match.RequireMatchWithData(push.RequestID, respRaw) } if push.Event == subscribeEvent || push.Event == unsubscribeEvent { - return g.Websocket.Match.EnsureMatchWithData(push.ID, respRaw) + return g.Websocket.Match.RequireMatchWithData(push.ID, respRaw) } switch push.Channel { // TODO: Convert function params below to only use push.Result diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index 956239e25a5..e364b2b84fc 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -222,7 +222,7 @@ func (ku *Kucoin) wsHandleData(respData []byte) error { return nil } if resp.ID != "" { - return ku.Websocket.Match.EnsureMatchWithData("msgID:"+resp.ID, respData) + return ku.Websocket.Match.RequireMatchWithData("msgID:"+resp.ID, respData) } topicInfo := strings.Split(resp.Topic, ":") switch topicInfo[0] { diff --git a/exchanges/stream/stream_match.go b/exchanges/stream/stream_match.go index 6d8949e08e6..430688a816b 100644 --- a/exchanges/stream/stream_match.go +++ b/exchanges/stream/stream_match.go @@ -51,9 +51,9 @@ func (m *Match) IncomingWithData(signature any, data []byte) bool { return true } -// EnsureMatchWithData validates that incoming data matches a request's signature. +// RequireMatchWithData validates that incoming data matches a request's signature. // If a match is found, the data is processed; otherwise, it returns an error. -func (m *Match) EnsureMatchWithData(signature any, data []byte) error { +func (m *Match) RequireMatchWithData(signature any, data []byte) error { if m.IncomingWithData(signature, data) { return nil } diff --git a/exchanges/stream/stream_match_test.go b/exchanges/stream/stream_match_test.go index 5a873ce80c6..ef7db92b617 100644 --- a/exchanges/stream/stream_match_test.go +++ b/exchanges/stream/stream_match_test.go @@ -52,17 +52,17 @@ func TestRemoveSignature(t *testing.T) { } } -func TestEnsureMatchWithData(t *testing.T) { +func TestRequireMatchWithData(t *testing.T) { t.Parallel() match := NewMatch() - err := match.EnsureMatchWithData("hello", []byte("world")) + err := match.RequireMatchWithData("hello", []byte("world")) require.ErrorIs(t, err, ErrSignatureNotMatched, "Should error on unmatched signature") assert.Contains(t, err.Error(), "world", "Should contain the data in the error message") assert.Contains(t, err.Error(), "hello", "Should contain the signature in the error message") ch, err := match.Set("hello", 1) require.NoError(t, err, "Set must not error") - err = match.EnsureMatchWithData("hello", []byte("world")) + err = match.RequireMatchWithData("hello", []byte("world")) require.NoError(t, err, "Should not error on matched signature") assert.Equal(t, "world", string(<-ch)) } From 254ad2bfb24cb934f7768c8eac2b4790e62c42f2 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 12 Dec 2024 15:55:06 +1100 Subject: [PATCH 128/138] ch name and commentary --- exchanges/stream/websocket.go | 18 +++++++++--------- exchanges/stream/websocket_test.go | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 1e716d9f222..737c0eadc7f 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -68,7 +68,7 @@ var ( errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") errExchangeConfigEmpty = errors.New("exchange config is empty") errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") - errConnectionSignatureNotSet = errors.New("connection signature not set") + errMessageFilterNotSet = errors.New("message filter not set") errMessageFilterNotComparable = errors.New("message filter is not comparable") ) @@ -1263,15 +1263,15 @@ func signalReceived(ch chan struct{}) bool { } } -// GetConnection returns a connection by connection signature (defined in wrapper setup) for request and response -// handling in a multi connection context. -func (w *Websocket) GetConnection(connSignature any) (Connection, error) { +// GetConnection returns a connection by message filter (defined in exchange package _wrapper.go websocket connection) +// for request and response handling in a multi connection context. +func (w *Websocket) GetConnection(messageFilter any) (Connection, error) { if w == nil { return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, w) } - if connSignature == nil { - return nil, errConnectionSignatureNotSet + if messageFilter == nil { + return nil, errMessageFilterNotSet } w.m.Lock() @@ -1286,13 +1286,13 @@ func (w *Websocket) GetConnection(connSignature any) (Connection, error) { } for _, wrapper := range w.connectionManager { - if wrapper.Setup.MessageFilter == connSignature { + if wrapper.Setup.MessageFilter == messageFilter { if wrapper.Connection == nil { - return nil, fmt.Errorf("%s: %s %w: %v", w.exchangeName, wrapper.Setup.URL, ErrNotConnected, connSignature) + return nil, fmt.Errorf("%s: %s %w associated with message filter: '%v'", w.exchangeName, wrapper.Setup.URL, ErrNotConnected, messageFilter) } return wrapper.Connection, nil } } - return nil, fmt.Errorf("%s: %w: %v", w.exchangeName, ErrRequestRouteNotFound, connSignature) + return nil, fmt.Errorf("%s: %w associated with message filter: '%v'", w.exchangeName, ErrRequestRouteNotFound, messageFilter) } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 8ca6709a10e..1008de78aee 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1509,7 +1509,7 @@ func TestGetConnection(t *testing.T) { ws = &Websocket{} _, err = ws.GetConnection(nil) - require.ErrorIs(t, err, errConnectionSignatureNotSet) + require.ErrorIs(t, err, errMessageFilterNotSet) _, err = ws.GetConnection("testURL") require.ErrorIs(t, err, errCannotObtainOutboundConnection) From e803f77640e05a9b7fd333a1e30a1eb19bf14ef5 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 12 Dec 2024 16:17:45 +1100 Subject: [PATCH 129/138] fix bug add test --- exchanges/stream/websocket_connection.go | 3 ++- exchanges/stream/websocket_test.go | 31 ++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 92aace90ef1..55fd71682e6 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -344,6 +344,7 @@ func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature an defer timeout.Stop() resps := make([][]byte, 0, expected) +inspection: for range expected { select { case resp := <-ch: @@ -351,7 +352,7 @@ func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature an // Checks recently received message to determine if this is in fact the final message in a sequence of messages. if messageInspector != nil && messageInspector.IsFinal(resp) { w.Match.RemoveSignature(signature) - return resps, nil + break inspection } case <-timeout.C: w.Match.RemoveSignature(signature) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 1008de78aee..06a71a78fbc 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -769,9 +769,36 @@ func TestSendMessageReturnResponse(t *testing.T) { assert.ErrorIs(t, err, ErrSignatureTimeout, "SendMessageReturnResponse should error when request ID not found") } -type inspection struct{} +func TestWaitForResponses(t *testing.T) { + t.Parallel() + dummy := &WebsocketConnection{ + ResponseMaxLimit: time.Nanosecond, + Match: NewMatch(), + } + _, err := dummy.waitForResponses(context.Background(), "silly", nil, 1, inspection{}) + require.ErrorIs(t, err, ErrSignatureTimeout) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = dummy.waitForResponses(ctx, "silly", nil, 1, inspection{}) + require.ErrorIs(t, err, context.Canceled) + + // test break early and hit verbose path + ch := make(chan []byte, 1) + ch <- []byte("hello") + ctx = request.WithVerbose(context.Background()) + dummy.ResponseMaxLimit = time.Second + got, err := dummy.waitForResponses(ctx, "silly", ch, 2, inspection{breakEarly: true}) + require.NoError(t, err, context.Canceled) + require.Len(t, got, 1) + assert.Equal(t, "hello", string(got[0])) +} + +type inspection struct { + breakEarly bool +} -func (inspection) IsFinal([]byte) bool { return false } +func (i inspection) IsFinal([]byte) bool { return i.breakEarly } type reporter struct { name string From a8d30db71bf1491dd7285812c044ef57584040aa Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 12 Dec 2024 16:20:05 +1100 Subject: [PATCH 130/138] rm a thing --- exchanges/stream/websocket_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 06a71a78fbc..e1eb5c8f408 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -789,7 +789,7 @@ func TestWaitForResponses(t *testing.T) { ctx = request.WithVerbose(context.Background()) dummy.ResponseMaxLimit = time.Second got, err := dummy.waitForResponses(ctx, "silly", ch, 2, inspection{breakEarly: true}) - require.NoError(t, err, context.Canceled) + require.NoError(t, err) require.Len(t, got, 1) assert.Equal(t, "hello", string(got[0])) } From 975a54c0e83550c5e7060aada52295fdd43ddb63 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 12 Dec 2024 16:38:08 +1100 Subject: [PATCH 131/138] fix test --- exchanges/stream/websocket_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index e1eb5c8f408..b6f3a762404 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -778,6 +778,7 @@ func TestWaitForResponses(t *testing.T) { _, err := dummy.waitForResponses(context.Background(), "silly", nil, 1, inspection{}) require.ErrorIs(t, err, ErrSignatureTimeout) + dummy.ResponseMaxLimit = time.Second ctx, cancel := context.WithCancel(context.Background()) cancel() _, err = dummy.waitForResponses(ctx, "silly", nil, 1, inspection{}) @@ -787,7 +788,7 @@ func TestWaitForResponses(t *testing.T) { ch := make(chan []byte, 1) ch <- []byte("hello") ctx = request.WithVerbose(context.Background()) - dummy.ResponseMaxLimit = time.Second + got, err := dummy.waitForResponses(ctx, "silly", ch, 2, inspection{breakEarly: true}) require.NoError(t, err) require.Len(t, got, 1) From 1bdcd16794962f45da738f57a6c76a7e6fbadc73 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 20 Dec 2024 07:13:51 +1100 Subject: [PATCH 132/138] Update engine/engine.go Co-authored-by: Adrian Gallagher --- engine/engine.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/engine.go b/engine/engine.go index 15d8abe2776..8ad14d87ed6 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -793,7 +793,7 @@ func (bot *Engine) LoadExchange(name string) error { if !bot.Settings.EnableExchangeHTTPRateLimiter { err = exch.DisableRateLimiter() if err != nil { - gctlog.Errorf(gctlog.ExchangeSys, "error disabling rate limiter for %s: %v", exch.GetName(), err) + gctlog.Errorf(gctlog.ExchangeSys, "%s error disabling rate limiter: %v", exch.GetName(), err) } else { gctlog.Warnf(gctlog.ExchangeSys, "%s rate limiting has been turned off", exch.GetName()) } From 8f9e9535b27452ef7328066f9560df02036963fa Mon Sep 17 00:00:00 2001 From: shazbert Date: Fri, 20 Dec 2024 07:19:07 +1100 Subject: [PATCH 133/138] thrasher: nits --- exchanges/gateio/gateio_websocket_request_spot_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go index 1ee75f34859..0ecc396e350 100644 --- a/exchanges/gateio/gateio_websocket_request_spot_test.go +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -92,8 +92,10 @@ func TestWebsocketOrderCancelSpot(t *testing.T) { func TestWebsocketOrderCancelAllByIDsSpot(t *testing.T) { t.Parallel() + _, err := g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{}) + require.ErrorIs(t, err, errNoOrdersToCancel) out := WebsocketOrderBatchRequest{} - _, err := g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{out}) + _, err = g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{out}) require.ErrorIs(t, err, order.ErrOrderIDNotSet) out.OrderID = "1337" _, err = g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{out}) From 0d7b7dbe620503aa5f9be497c00521d8a0b9732c Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 20 Dec 2024 12:42:03 +1100 Subject: [PATCH 134/138] Update exchanges/stream/stream_match_test.go Co-authored-by: Adrian Gallagher --- exchanges/stream/stream_match_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/stream_match_test.go b/exchanges/stream/stream_match_test.go index ef7db92b617..18c81484088 100644 --- a/exchanges/stream/stream_match_test.go +++ b/exchanges/stream/stream_match_test.go @@ -56,7 +56,7 @@ func TestRequireMatchWithData(t *testing.T) { t.Parallel() match := NewMatch() err := match.RequireMatchWithData("hello", []byte("world")) - require.ErrorIs(t, err, ErrSignatureNotMatched, "Should error on unmatched signature") + require.ErrorIs(t, err, ErrSignatureNotMatched, "Must error on unmatched signature") assert.Contains(t, err.Error(), "world", "Should contain the data in the error message") assert.Contains(t, err.Error(), "hello", "Should contain the signature in the error message") From d164878de4a71754cd51c460803ec0a1a45c59a0 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 20 Dec 2024 12:42:13 +1100 Subject: [PATCH 135/138] Update exchanges/stream/stream_match_test.go Co-authored-by: Adrian Gallagher --- exchanges/stream/stream_match_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/stream_match_test.go b/exchanges/stream/stream_match_test.go index 18c81484088..8053cb37934 100644 --- a/exchanges/stream/stream_match_test.go +++ b/exchanges/stream/stream_match_test.go @@ -63,6 +63,6 @@ func TestRequireMatchWithData(t *testing.T) { ch, err := match.Set("hello", 1) require.NoError(t, err, "Set must not error") err = match.RequireMatchWithData("hello", []byte("world")) - require.NoError(t, err, "Should not error on matched signature") + require.NoError(t, err, "Must not error on matched signature") assert.Equal(t, "world", string(<-ch)) } From 544fbccdedbe9b76a80cde4383e3edfc1e184e49 Mon Sep 17 00:00:00 2001 From: shazbert Date: Fri, 20 Dec 2024 13:00:44 +1100 Subject: [PATCH 136/138] GK: nits rn websocket functions --- .../gateio/gateio_websocket_request_spot.go | 24 ++++---- .../gateio_websocket_request_spot_test.go | 60 +++++++++---------- 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index 44cb73efc78..d08d806bbc2 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -22,9 +22,9 @@ var ( errChannelEmpty = errors.New("channel cannot be empty") ) -// WebsocketOrderPlaceSpot places an order via the websocket connection. You can +// WebsocketSpotSubmitOrder submits an order via the websocket connection. You can // send multiple orders in a single request. But only for one asset route. -func (g *Gateio) WebsocketOrderPlaceSpot(ctx context.Context, orders []WebsocketOrder) ([]WebsocketOrderResponse, error) { +func (g *Gateio) WebsocketSpotSubmitOrder(ctx context.Context, orders []WebsocketOrder) ([]WebsocketOrderResponse, error) { if len(orders) == 0 { return nil, errOrdersEmpty } @@ -56,8 +56,8 @@ func (g *Gateio) WebsocketOrderPlaceSpot(ctx context.Context, orders []Websocket return resp, g.SendWebsocketRequest(ctx, spotBatchOrdersEPL, "spot.order_place", asset.Spot, orders, &resp, 2) } -// WebsocketOrderCancelSpot cancels an order via the websocket connection -func (g *Gateio) WebsocketOrderCancelSpot(ctx context.Context, orderID string, pair currency.Pair, account string) (*WebsocketOrderResponse, error) { +// WebsocketSpotCancelOrder cancels an order via the websocket connection +func (g *Gateio) WebsocketSpotCancelOrder(ctx context.Context, orderID string, pair currency.Pair, account string) (*WebsocketOrderResponse, error) { if orderID == "" { return nil, order.ErrOrderIDNotSet } @@ -71,8 +71,8 @@ func (g *Gateio) WebsocketOrderCancelSpot(ctx context.Context, orderID string, p return &resp, g.SendWebsocketRequest(ctx, spotCancelSingleOrderEPL, "spot.order_cancel", asset.Spot, params, &resp, 1) } -// WebsocketOrderCancelAllByIDsSpot cancels multiple orders via the websocket -func (g *Gateio) WebsocketOrderCancelAllByIDsSpot(ctx context.Context, o []WebsocketOrderBatchRequest) ([]WebsocketCancellAllResponse, error) { +// WebsocketSpotCancelAllOrdersByIDs cancels multiple orders via the websocket +func (g *Gateio) WebsocketSpotCancelAllOrdersByIDs(ctx context.Context, o []WebsocketOrderBatchRequest) ([]WebsocketCancellAllResponse, error) { if len(o) == 0 { return nil, errNoOrdersToCancel } @@ -90,8 +90,8 @@ func (g *Gateio) WebsocketOrderCancelAllByIDsSpot(ctx context.Context, o []Webso return resp, g.SendWebsocketRequest(ctx, spotCancelBatchOrdersEPL, "spot.order_cancel_ids", asset.Spot, o, &resp, 2) } -// WebsocketOrderCancelAllByPairSpot cancels all orders for a specific pair -func (g *Gateio) WebsocketOrderCancelAllByPairSpot(ctx context.Context, pair currency.Pair, side order.Side, account string) ([]WebsocketOrderResponse, error) { +// WebsocketSpotCancelAllOrdersByPair cancels all orders for a specific pair +func (g *Gateio) WebsocketSpotCancelAllOrdersByPair(ctx context.Context, pair currency.Pair, side order.Side, account string) ([]WebsocketOrderResponse, error) { if !pair.IsEmpty() && side == order.UnknownSide { // This case will cancel all orders for every pair, this can be introduced later return nil, fmt.Errorf("'%v' %w while pair is set", side, order.ErrSideIsInvalid) @@ -112,8 +112,8 @@ func (g *Gateio) WebsocketOrderCancelAllByPairSpot(ctx context.Context, pair cur return resp, g.SendWebsocketRequest(ctx, spotCancelAllOpenOrdersEPL, "spot.order_cancel_cp", asset.Spot, params, &resp, 1) } -// WebsocketOrderAmendSpot amends an order via the websocket connection -func (g *Gateio) WebsocketOrderAmendSpot(ctx context.Context, amend *WebsocketAmendOrder) (*WebsocketOrderResponse, error) { +// WebsocketSpotAmendOrder amends an order via the websocket connection +func (g *Gateio) WebsocketSpotAmendOrder(ctx context.Context, amend *WebsocketAmendOrder) (*WebsocketOrderResponse, error) { if amend == nil { return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, amend) } @@ -134,8 +134,8 @@ func (g *Gateio) WebsocketOrderAmendSpot(ctx context.Context, amend *WebsocketAm return &resp, g.SendWebsocketRequest(ctx, spotAmendOrderEPL, "spot.order_amend", asset.Spot, amend, &resp, 1) } -// WebsocketGetOrderStatusSpot gets the status of an order via the websocket connection -func (g *Gateio) WebsocketGetOrderStatusSpot(ctx context.Context, orderID string, pair currency.Pair, account string) (*WebsocketOrderResponse, error) { +// WebsocketSpotGetOrderStatus gets the status of an order via the websocket connection +func (g *Gateio) WebsocketSpotGetOrderStatus(ctx context.Context, orderID string, pair currency.Pair, account string) (*WebsocketOrderResponse, error) { if orderID == "" { return nil, order.ErrOrderIDNotSet } diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go index 0ecc396e350..0f5d3e0c0c3 100644 --- a/exchanges/gateio/gateio_websocket_request_spot_test.go +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -36,21 +36,21 @@ func TestWebsocketLogin(t *testing.T) { require.NoError(t, err) } -func TestWebsocketOrderPlaceSpot(t *testing.T) { +func TestWebsocketSpotSubmitOrder(t *testing.T) { t.Parallel() - _, err := g.WebsocketOrderPlaceSpot(context.Background(), nil) + _, err := g.WebsocketSpotSubmitOrder(context.Background(), nil) require.ErrorIs(t, err, errOrdersEmpty) - _, err = g.WebsocketOrderPlaceSpot(context.Background(), make([]WebsocketOrder, 1)) + _, err = g.WebsocketSpotSubmitOrder(context.Background(), make([]WebsocketOrder, 1)) require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) out := WebsocketOrder{CurrencyPair: "BTC_USDT"} - _, err = g.WebsocketOrderPlaceSpot(context.Background(), []WebsocketOrder{out}) + _, err = g.WebsocketSpotSubmitOrder(context.Background(), []WebsocketOrder{out}) require.ErrorIs(t, err, order.ErrSideIsInvalid) out.Side = strings.ToLower(order.Buy.String()) - _, err = g.WebsocketOrderPlaceSpot(context.Background(), []WebsocketOrder{out}) + _, err = g.WebsocketSpotSubmitOrder(context.Background(), []WebsocketOrder{out}) require.ErrorIs(t, err, errInvalidAmount) out.Amount = "0.0003" out.Type = "limit" - _, err = g.WebsocketOrderPlaceSpot(context.Background(), []WebsocketOrder{out}) + _, err = g.WebsocketSpotSubmitOrder(context.Background(), []WebsocketOrder{out}) require.ErrorIs(t, err, errInvalidPrice) out.Price = "20000" @@ -60,21 +60,21 @@ func TestWebsocketOrderPlaceSpot(t *testing.T) { g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes // test single order - got, err := g.WebsocketOrderPlaceSpot(context.Background(), []WebsocketOrder{out}) + got, err := g.WebsocketSpotSubmitOrder(context.Background(), []WebsocketOrder{out}) require.NoError(t, err) require.NotEmpty(t, got) // test batch orders - got, err = g.WebsocketOrderPlaceSpot(context.Background(), []WebsocketOrder{out, out}) + got, err = g.WebsocketSpotSubmitOrder(context.Background(), []WebsocketOrder{out, out}) require.NoError(t, err) require.NotEmpty(t, got) } -func TestWebsocketOrderCancelSpot(t *testing.T) { +func TestWebsocketSpotCancelOrder(t *testing.T) { t.Parallel() - _, err := g.WebsocketOrderCancelSpot(context.Background(), "", currency.EMPTYPAIR, "") + _, err := g.WebsocketSpotCancelOrder(context.Background(), "", currency.EMPTYPAIR, "") require.ErrorIs(t, err, order.ErrOrderIDNotSet) - _, err = g.WebsocketOrderCancelSpot(context.Background(), "1337", currency.EMPTYPAIR, "") + _, err = g.WebsocketSpotCancelOrder(context.Background(), "1337", currency.EMPTYPAIR, "") require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) btcusdt, err := currency.NewPairFromString("BTC_USDT") @@ -85,20 +85,20 @@ func TestWebsocketOrderCancelSpot(t *testing.T) { testexch.UpdatePairsOnce(t, g) g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes - got, err := g.WebsocketOrderCancelSpot(context.Background(), "644913098758", btcusdt, "") + got, err := g.WebsocketSpotCancelOrder(context.Background(), "644913098758", btcusdt, "") require.NoError(t, err) require.NotEmpty(t, got) } -func TestWebsocketOrderCancelAllByIDsSpot(t *testing.T) { +func TestWebsocketSpotCancelAllOrdersByIDs(t *testing.T) { t.Parallel() - _, err := g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{}) + _, err := g.WebsocketSpotCancelAllOrdersByIDs(context.Background(), []WebsocketOrderBatchRequest{}) require.ErrorIs(t, err, errNoOrdersToCancel) out := WebsocketOrderBatchRequest{} - _, err = g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{out}) + _, err = g.WebsocketSpotCancelAllOrdersByIDs(context.Background(), []WebsocketOrderBatchRequest{out}) require.ErrorIs(t, err, order.ErrOrderIDNotSet) out.OrderID = "1337" - _, err = g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{out}) + _, err = g.WebsocketSpotCancelAllOrdersByIDs(context.Background(), []WebsocketOrderBatchRequest{out}) require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) out.Pair, err = currency.NewPairFromString("BTC_USDT") @@ -110,17 +110,17 @@ func TestWebsocketOrderCancelAllByIDsSpot(t *testing.T) { g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes out.OrderID = "644913101755" - got, err := g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{out}) + got, err := g.WebsocketSpotCancelAllOrdersByIDs(context.Background(), []WebsocketOrderBatchRequest{out}) require.NoError(t, err) require.NotEmpty(t, got) } -func TestWebsocketOrderCancelAllByPairSpot(t *testing.T) { +func TestWebsocketSpotCancelAllOrdersByPair(t *testing.T) { t.Parallel() pair, err := currency.NewPairFromString("LTC_USDT") require.NoError(t, err) - _, err = g.WebsocketOrderCancelAllByPairSpot(context.Background(), pair, 0, "") + _, err = g.WebsocketSpotCancelAllOrdersByPair(context.Background(), pair, 0, "") require.ErrorIs(t, err, order.ErrSideIsInvalid) sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) @@ -128,29 +128,29 @@ func TestWebsocketOrderCancelAllByPairSpot(t *testing.T) { testexch.UpdatePairsOnce(t, g) g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes - got, err := g.WebsocketOrderCancelAllByPairSpot(context.Background(), currency.EMPTYPAIR, order.Buy, "") + got, err := g.WebsocketSpotCancelAllOrdersByPair(context.Background(), currency.EMPTYPAIR, order.Buy, "") require.NoError(t, err) require.NotEmpty(t, got) } -func TestWebsocketOrderAmendSpot(t *testing.T) { +func TestWebsocketSpotAmendOrder(t *testing.T) { t.Parallel() - _, err := g.WebsocketOrderAmendSpot(context.Background(), nil) + _, err := g.WebsocketSpotAmendOrder(context.Background(), nil) require.ErrorIs(t, err, common.ErrNilPointer) amend := &WebsocketAmendOrder{} - _, err = g.WebsocketOrderAmendSpot(context.Background(), amend) + _, err = g.WebsocketSpotAmendOrder(context.Background(), amend) require.ErrorIs(t, err, order.ErrOrderIDNotSet) amend.OrderID = "1337" - _, err = g.WebsocketOrderAmendSpot(context.Background(), amend) + _, err = g.WebsocketSpotAmendOrder(context.Background(), amend) require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) amend.Pair, err = currency.NewPairFromString("BTC_USDT") require.NoError(t, err) - _, err = g.WebsocketOrderAmendSpot(context.Background(), amend) + _, err = g.WebsocketSpotAmendOrder(context.Background(), amend) require.ErrorIs(t, err, errInvalidAmount) amend.Amount = "0.0004" @@ -161,18 +161,18 @@ func TestWebsocketOrderAmendSpot(t *testing.T) { g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes amend.OrderID = "645029162673" - got, err := g.WebsocketOrderAmendSpot(context.Background(), amend) + got, err := g.WebsocketSpotAmendOrder(context.Background(), amend) require.NoError(t, err) require.NotEmpty(t, got) } -func TestWebsocketGetOrderStatusSpot(t *testing.T) { +func TestWebsocketSpotGetOrderStatus(t *testing.T) { t.Parallel() - _, err := g.WebsocketGetOrderStatusSpot(context.Background(), "", currency.EMPTYPAIR, "") + _, err := g.WebsocketSpotGetOrderStatus(context.Background(), "", currency.EMPTYPAIR, "") require.ErrorIs(t, err, order.ErrOrderIDNotSet) - _, err = g.WebsocketGetOrderStatusSpot(context.Background(), "1337", currency.EMPTYPAIR, "") + _, err = g.WebsocketSpotGetOrderStatus(context.Background(), "1337", currency.EMPTYPAIR, "") require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) @@ -183,7 +183,7 @@ func TestWebsocketGetOrderStatusSpot(t *testing.T) { pair, err := currency.NewPairFromString("BTC_USDT") require.NoError(t, err) - got, err := g.WebsocketGetOrderStatusSpot(context.Background(), "644999650452", pair, "") + got, err := g.WebsocketSpotGetOrderStatus(context.Background(), "644999650452", pair, "") require.NoError(t, err) require.NotEmpty(t, got) } From df476e213deadde70e426f439d1edf91db62ea6d Mon Sep 17 00:00:00 2001 From: shazbert Date: Fri, 20 Dec 2024 13:18:30 +1100 Subject: [PATCH 137/138] explicit function names for single to multi outbound orders --- .../gateio/gateio_websocket_request_spot.go | 9 ++++- .../gateio_websocket_request_spot_test.go | 40 +++++++++++++++---- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index d08d806bbc2..77579a373a3 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -22,9 +22,14 @@ var ( errChannelEmpty = errors.New("channel cannot be empty") ) -// WebsocketSpotSubmitOrder submits an order via the websocket connection. You can +// WebsocketSpotSubmitOrder submits an order via the websocket connection +func (g *Gateio) WebsocketSpotSubmitOrder(ctx context.Context, order WebsocketOrder) ([]WebsocketOrderResponse, error) { + return g.WebsocketSpotSubmitOrders(ctx, []WebsocketOrder{order}) +} + +// WebsocketSpotSubmitOrders submits orders via the websocket connection. You can // send multiple orders in a single request. But only for one asset route. -func (g *Gateio) WebsocketSpotSubmitOrder(ctx context.Context, orders []WebsocketOrder) ([]WebsocketOrderResponse, error) { +func (g *Gateio) WebsocketSpotSubmitOrders(ctx context.Context, orders []WebsocketOrder) ([]WebsocketOrderResponse, error) { if len(orders) == 0 { return nil, errOrdersEmpty } diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go index 0f5d3e0c0c3..9e756d8348a 100644 --- a/exchanges/gateio/gateio_websocket_request_spot_test.go +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -38,19 +38,45 @@ func TestWebsocketLogin(t *testing.T) { func TestWebsocketSpotSubmitOrder(t *testing.T) { t.Parallel() - _, err := g.WebsocketSpotSubmitOrder(context.Background(), nil) + _, err := g.WebsocketSpotSubmitOrder(context.Background(), WebsocketOrder{}) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + out := WebsocketOrder{CurrencyPair: "BTC_USDT"} + _, err = g.WebsocketSpotSubmitOrder(context.Background(), out) + require.ErrorIs(t, err, order.ErrSideIsInvalid) + out.Side = strings.ToLower(order.Buy.String()) + _, err = g.WebsocketSpotSubmitOrder(context.Background(), out) + require.ErrorIs(t, err, errInvalidAmount) + out.Amount = "0.0003" + out.Type = "limit" + _, err = g.WebsocketSpotSubmitOrder(context.Background(), out) + require.ErrorIs(t, err, errInvalidPrice) + out.Price = "20000" + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + got, err := g.WebsocketSpotSubmitOrder(context.Background(), out) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketSpotSubmitOrders(t *testing.T) { + t.Parallel() + _, err := g.WebsocketSpotSubmitOrders(context.Background(), nil) require.ErrorIs(t, err, errOrdersEmpty) - _, err = g.WebsocketSpotSubmitOrder(context.Background(), make([]WebsocketOrder, 1)) + _, err = g.WebsocketSpotSubmitOrders(context.Background(), make([]WebsocketOrder, 1)) require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) out := WebsocketOrder{CurrencyPair: "BTC_USDT"} - _, err = g.WebsocketSpotSubmitOrder(context.Background(), []WebsocketOrder{out}) + _, err = g.WebsocketSpotSubmitOrders(context.Background(), []WebsocketOrder{out}) require.ErrorIs(t, err, order.ErrSideIsInvalid) out.Side = strings.ToLower(order.Buy.String()) - _, err = g.WebsocketSpotSubmitOrder(context.Background(), []WebsocketOrder{out}) + _, err = g.WebsocketSpotSubmitOrders(context.Background(), []WebsocketOrder{out}) require.ErrorIs(t, err, errInvalidAmount) out.Amount = "0.0003" out.Type = "limit" - _, err = g.WebsocketSpotSubmitOrder(context.Background(), []WebsocketOrder{out}) + _, err = g.WebsocketSpotSubmitOrders(context.Background(), []WebsocketOrder{out}) require.ErrorIs(t, err, errInvalidPrice) out.Price = "20000" @@ -60,12 +86,12 @@ func TestWebsocketSpotSubmitOrder(t *testing.T) { g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes // test single order - got, err := g.WebsocketSpotSubmitOrder(context.Background(), []WebsocketOrder{out}) + got, err := g.WebsocketSpotSubmitOrders(context.Background(), []WebsocketOrder{out}) require.NoError(t, err) require.NotEmpty(t, got) // test batch orders - got, err = g.WebsocketSpotSubmitOrder(context.Background(), []WebsocketOrder{out, out}) + got, err = g.WebsocketSpotSubmitOrders(context.Background(), []WebsocketOrder{out, out}) require.NoError(t, err) require.NotEmpty(t, got) } From 2182ddb0f389b208b56216d586b9d916be8e0001 Mon Sep 17 00:00:00 2001 From: shazbert Date: Fri, 20 Dec 2024 13:22:37 +1100 Subject: [PATCH 138/138] linter: fix --- exchanges/gateio/gateio_websocket_request_spot.go | 4 ++-- exchanges/gateio/gateio_websocket_request_spot_test.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go index 77579a373a3..25ac1ea1689 100644 --- a/exchanges/gateio/gateio_websocket_request_spot.go +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -23,8 +23,8 @@ var ( ) // WebsocketSpotSubmitOrder submits an order via the websocket connection -func (g *Gateio) WebsocketSpotSubmitOrder(ctx context.Context, order WebsocketOrder) ([]WebsocketOrderResponse, error) { - return g.WebsocketSpotSubmitOrders(ctx, []WebsocketOrder{order}) +func (g *Gateio) WebsocketSpotSubmitOrder(ctx context.Context, order *WebsocketOrder) ([]WebsocketOrderResponse, error) { + return g.WebsocketSpotSubmitOrders(ctx, []WebsocketOrder{*order}) } // WebsocketSpotSubmitOrders submits orders via the websocket connection. You can diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go index 9e756d8348a..7933c117294 100644 --- a/exchanges/gateio/gateio_websocket_request_spot_test.go +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -38,9 +38,9 @@ func TestWebsocketLogin(t *testing.T) { func TestWebsocketSpotSubmitOrder(t *testing.T) { t.Parallel() - _, err := g.WebsocketSpotSubmitOrder(context.Background(), WebsocketOrder{}) + _, err := g.WebsocketSpotSubmitOrder(context.Background(), &WebsocketOrder{}) require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) - out := WebsocketOrder{CurrencyPair: "BTC_USDT"} + out := &WebsocketOrder{CurrencyPair: "BTC_USDT"} _, err = g.WebsocketSpotSubmitOrder(context.Background(), out) require.ErrorIs(t, err, order.ErrSideIsInvalid) out.Side = strings.ToLower(order.Buy.String())