From 9205394f6bceda53bcd1cdb76ace448fc1aaed95 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Wed, 21 Aug 2024 14:56:25 +0700 Subject: [PATCH] CoinbasePro: Add subscription templating --- exchanges/coinbasepro/coinbasepro.go | 1 - exchanges/coinbasepro/coinbasepro_test.go | 53 ++++- .../coinbasepro/coinbasepro_websocket.go | 201 ++++++++++-------- exchanges/coinbasepro/coinbasepro_wrapper.go | 12 +- exchanges/subscription/subscription.go | 3 + 5 files changed, 159 insertions(+), 111 deletions(-) diff --git a/exchanges/coinbasepro/coinbasepro.go b/exchanges/coinbasepro/coinbasepro.go index ecf31ac6cfc..fcf9e5bc626 100644 --- a/exchanges/coinbasepro/coinbasepro.go +++ b/exchanges/coinbasepro/coinbasepro.go @@ -160,7 +160,6 @@ var ( errPairEmpty = errors.New("pair cannot be empty") errStringConvert = errors.New("unable to convert into string value") errFloatConvert = errors.New("unable to convert into float64 value") - errNoCredsUser = errors.New("no credentials when attempting to subscribe to authenticated channel user") errWrappedAssetEmpty = errors.New("wrapped asset cannot be empty") errExpectedOneTickerReturned = errors.New("expected one ticker to be returned") ) diff --git a/exchanges/coinbasepro/coinbasepro_test.go b/exchanges/coinbasepro/coinbasepro_test.go index e227e788e81..ce7aacdfde0 100644 --- a/exchanges/coinbasepro/coinbasepro_test.go +++ b/exchanges/coinbasepro/coinbasepro_test.go @@ -29,6 +29,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" + testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" gctlog "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -1659,17 +1660,30 @@ func TestProcessSnapshotUpdate(t *testing.T) { assert.NoError(t, err) } -func TestGenerateDefaultSubscriptions(t *testing.T) { - comparison := subscription.List{{Channel: "heartbeats"}, {Channel: "status"}, {Channel: "ticker"}, - {Channel: "ticker_batch"}, {Channel: "candles"}, {Channel: "market_trades"}, {Channel: "level2"}} - for i := range comparison { - comparison[i].Pairs = currency.Pairs{ - currency.NewPairWithDelimiter(testCrypto.String(), testFiat.String(), "-")} - comparison[i].Asset = asset.Spot +func TestGenerateSubscriptions(t *testing.T) { + t.Parallel() + c := new(CoinbasePro) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + if err := testexch.Setup(c); err != nil { + log.Fatal(err) } - resp, err := c.generateSubscriptions() + c.Websocket.SetCanUseAuthenticatedEndpoints(true) + p, err := c.GetEnabledPairs(asset.Spot) require.NoError(t, err) - assert.ElementsMatch(t, comparison, resp) + exp := subscription.List{} + for _, baseSub := range defaultSubscriptions.Enabled() { + s := baseSub.Clone() + s.QualifiedChannel = subscriptionNames[s.Channel] + if s.Asset != asset.Empty { + s.Pairs = p + } + exp = append(exp, s) + } + subs, err := c.generateSubscriptions() + require.NoError(t, err) + testsubs.EqualLists(t, exp, subs) + + _, err = subscription.List{{Channel: "wibble"}}.ExpandTemplates(c) + assert.ErrorContains(t, err, "subscription channel not supported: wibble") } func TestSubscribeUnsubscribe(t *testing.T) { @@ -1920,3 +1934,24 @@ func testGetOneArg[G getOneArgResp](t *testing.T, f getOneArgAssertNotEmpty[G], assert.NoError(t, err) assert.NotEmpty(t, resp, errExpectedNonEmpty) } + +func TestCheckSubscriptions(t *testing.T) { + t.Parallel() + + c := &CoinbasePro{ + Base: exchange.Base{ + Config: &config.Exchange{ + Features: &config.FeaturesConfig{ + Subscriptions: subscription.List{ + {Enabled: true, Channel: "matches"}, + }, + }, + }, + Features: exchange.Features{}, + }, + } + + c.checkSubscriptions() + testsubs.EqualLists(t, defaultSubscriptions.Enabled(), c.Features.Subscriptions) + testsubs.EqualLists(t, defaultSubscriptions, c.Config.Features.Subscriptions) +} diff --git a/exchanges/coinbasepro/coinbasepro_websocket.go b/exchanges/coinbasepro/coinbasepro_websocket.go index f6b82f57099..62a450f05b2 100644 --- a/exchanges/coinbasepro/coinbasepro_websocket.go +++ b/exchanges/coinbasepro/coinbasepro_websocket.go @@ -13,6 +13,7 @@ import ( "net/http" "strconv" "strings" + "text/template" "time" "github.com/buger/jsonparser" @@ -20,8 +21,6 @@ import ( "github.com/pkg/errors" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" - "github.com/thrasher-corp/gocryptotrader/currency" - exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" @@ -35,6 +34,34 @@ const ( coinbaseproWebsocketURL = "wss://advanced-trade-ws.coinbase.com" ) +var subscriptionNames = map[string]string{ + subscription.HeartbeatChannel: "heartbeats", + subscription.TickerChannel: "ticker", + subscription.CandlesChannel: "candles", + subscription.AllTradesChannel: "market_trades", + subscription.OrderbookChannel: "level2", + subscription.MyAccountChannel: "user", + "status": "status", + "ticker_batch": "ticker_batch", + /* Not Implemented: + "futures_balance_summary": "futures_balance_summary", + */ +} + +var defaultSubscriptions = subscription.List{ + {Enabled: true, Channel: subscription.HeartbeatChannel}, + {Enabled: true, Channel: "status"}, + {Enabled: true, Asset: asset.Spot, Channel: subscription.TickerChannel}, + {Enabled: true, Asset: asset.Spot, Channel: subscription.CandlesChannel}, + {Enabled: true, Asset: asset.Spot, Channel: subscription.AllTradesChannel}, + {Enabled: true, Asset: asset.Spot, Channel: subscription.OrderbookChannel}, + {Enabled: true, Channel: subscription.MyAccountChannel, Authenticated: true}, + {Enabled: false, Asset: asset.Spot, Channel: "ticker_batch"}, + /* Not Implemented: + {Enabled: false, Asset: asset.Spot, Channel: "futures_balance_summary", Authenticated: true}, + */ +} + // WsConnect initiates a websocket connection func (c *CoinbasePro) WsConnect() error { if !c.Websocket.IsEnabled() || !c.IsEnabled() { @@ -300,61 +327,67 @@ func (c *CoinbasePro) ProcessUpdate(update *WebsocketOrderbookDataHolder, timest // GenerateDefaultSubscriptions Adds default subscriptions to websocket to be handled by ManageSubscriptions() func (c *CoinbasePro) generateSubscriptions() (subscription.List, error) { - var channels = []string{ - "heartbeats", - "status", - "ticker", - "ticker_batch", - "candles", - "market_trades", - "level2", - } - enabledPairs, err := c.GetEnabledPairs(asset.Spot) - if err != nil { - return nil, err - } - var subscriptions subscription.List - for i := range channels { - subscriptions = append(subscriptions, &subscription.Subscription{ - Channel: channels[i], - Pairs: enabledPairs, - Asset: asset.Spot, - }) - } - return subscriptions, nil + return c.Features.Subscriptions.ExpandTemplates(c) } -// Subscribe sends a websocket message to receive data from the channel -func (c *CoinbasePro) Subscribe(channelsToSubscribe subscription.List) error { - chanKeys := make(map[string]currency.Pairs) - for i := range channelsToSubscribe { - chanKeys[channelsToSubscribe[i].Channel] = - chanKeys[channelsToSubscribe[i].Channel].Add(channelsToSubscribe[i].Pairs...) - } - for s := range chanKeys { - err := c.sendRequest("subscribe", s, chanKeys[s]) - if err != nil { - return err +// GetSubscriptionTemplate returns a subscription channel template +func (c *CoinbasePro) GetSubscriptionTemplate(_ *subscription.Subscription) (*template.Template, error) { + return template.New("master.tmpl").Funcs(template.FuncMap{"channelName": channelName}).Parse(subTplText) +} + +// Subscribe sends a websocket message to receive data from a list of channels +func (c *CoinbasePro) Subscribe(subs subscription.List) error { + return c.ParallelChanOp(subs, func(subs subscription.List) error { return c.manageSubs("subscribe", subs) }, 1) +} + +// Unsubscribe sends a websocket message to stop receiving data from a list of channels +func (c *CoinbasePro) Unsubscribe(subs subscription.List) error { + return c.ParallelChanOp(subs, func(subs subscription.List) error { return c.manageSubs("unsubscribe", subs) }, 1) +} + +// manageSub subscribes or unsubscribes from a list of websocket channels +func (c *CoinbasePro) manageSubs(op string, subs subscription.List) error { + var errs error + subs, errs = subs.ExpandTemplates(c) + for _, s := range subs { + r := &WebsocketRequest{ + Type: op, + ProductIDs: s.Pairs.Strings(), + Channel: s.QualifiedChannel, + Timestamp: strconv.FormatInt(time.Now().Unix(), 10), + } + var err error + limitType := WSUnauthRate + if s.Authenticated { + limitType = WSAuthRate + err = c.signWsRequest(r) + } + if err == nil { + err = c.InitiateRateLimit(context.Background(), limitType) + } + if err == nil { + if err = c.Websocket.Conn.SendJSONMessage(r); err == nil { + err = c.Websocket.AddSuccessfulSubscriptions(s) + } } - time.Sleep(time.Millisecond * 10) + errs = common.AppendError(errs, err) } return nil } -// Unsubscribe sends a websocket message to stop receiving data from the channel -func (c *CoinbasePro) Unsubscribe(channelsToUnsubscribe subscription.List) error { - chanKeys := make(map[string]currency.Pairs) - for i := range channelsToUnsubscribe { - chanKeys[channelsToUnsubscribe[i].Channel] = - chanKeys[channelsToUnsubscribe[i].Channel].Add(channelsToUnsubscribe[i].Pairs...) +func (c *CoinbasePro) signWsRequest(r *WebsocketRequest) error { + creds, err := c.GetCredentials(context.Background()) + if err != nil { + return err } - for s := range chanKeys { - err := c.sendRequest("unsubscribe", s, chanKeys[s]) - if err != nil { - return err - } - time.Sleep(time.Millisecond * 10) + hmac, err := crypto.GetHMAC(crypto.HashSHA256, []byte(r.Timestamp+r.Channel+strings.Join(r.ProductIDs, ",")), []byte(creds.Secret)) + if err != nil { + return err } + // TODO: Implement JWT authentication once our REST implementation moves to it, or if there's + // an exchange-wide reform to enable multiple sets of authentication credentials + r.Key = creds.Key + r.Signature = hex.EncodeToString(hmac) return nil } @@ -421,51 +454,6 @@ func getTimestamp(rawData []byte) (time.Time, error) { return timestamp, nil } -// sendRequest is a helper function which sends a websocket message to the Coinbase server -func (c *CoinbasePro) sendRequest(msgType, channel string, productIDs currency.Pairs) error { - authenticated := true - creds, err := c.GetCredentials(context.Background()) - if err != nil { - if errors.Is(err, exchange.ErrCredentialsAreEmpty) || - errors.Is(err, exchange.ErrAuthenticationSupportNotEnabled) { - authenticated = false - if channel == "user" { - return errNoCredsUser - } - } else { - return err - } - } - n := strconv.FormatInt(time.Now().Unix(), 10) - req := WebsocketRequest{ - Type: msgType, - ProductIDs: productIDs.Strings(), - Channel: channel, - Timestamp: n, - } - if authenticated { - message := n + channel + productIDs.Join() - var hmac []byte - hmac, err = crypto.GetHMAC(crypto.HashSHA256, - []byte(message), - []byte(creds.Secret)) - if err != nil { - return err - } - // TODO: Implement JWT authentication once our REST implementation moves to it, or if there's - // an exchange-wide reform to enable multiple sets of authentication credentials - req.Key = creds.Key - req.Signature = hex.EncodeToString(hmac) - err = c.InitiateRateLimit(context.Background(), WSAuthRate) - } else { - err = c.InitiateRateLimit(context.Background(), WSUnauthRate) - } - if err != nil { - return fmt.Errorf("failed to rate limit websocket request: %w", err) - } - return c.Websocket.Conn.SendJSONMessage(req) -} - // processBidAskArray is a helper function that turns WebsocketOrderbookDataHolder into arrays // of bids and asks func processBidAskArray(data *WebsocketOrderbookDataHolder) (bids, asks orderbook.Tranches, err error) { @@ -515,3 +503,30 @@ func base64URLEncode(b []byte) string { s = strings.ReplaceAll(s, "/", "_") return s } + +// checkSubscriptions looks for incompatible subscriptions and if found replaces all with defaults +// This should be unnecessary and removable by mid-2025 +func (c *CoinbasePro) checkSubscriptions() { + for _, s := range c.Config.Features.Subscriptions { + switch s.Channel { + case "heartbeat", "level2_batch", "matches": + c.Config.Features.Subscriptions = defaultSubscriptions.Clone() + c.Features.Subscriptions = c.Config.Features.Subscriptions.Enabled() + return + } + } +} + +func channelName(s *subscription.Subscription) string { + if n, ok := subscriptionNames[s.Channel]; ok { + return n + } + panic(fmt.Errorf("%w: %s", subscription.ErrNotSupported, s.Channel)) +} + +const subTplText = ` +{{ range $asset, $pairs := $.AssetPairs }} + {{- channelName $.S -}} + {{- $.AssetSeparator }} +{{- end }} +` diff --git a/exchanges/coinbasepro/coinbasepro_wrapper.go b/exchanges/coinbasepro/coinbasepro_wrapper.go index 37e3eb7a208..539f5989e39 100644 --- a/exchanges/coinbasepro/coinbasepro_wrapper.go +++ b/exchanges/coinbasepro/coinbasepro_wrapper.go @@ -25,7 +25,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" - "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" "github.com/thrasher-corp/gocryptotrader/log" @@ -106,13 +105,7 @@ func (c *CoinbasePro) SetDefaults() { GlobalResultLimit: 300, }, }, - Subscriptions: subscription.List{ - {Enabled: true, Channel: "heartbeat"}, - {Enabled: true, Channel: "level2_batch"}, // Other orderbook feeds require authentication; This is batched in 50ms lots - {Enabled: true, Channel: "ticker"}, - {Enabled: true, Channel: "user", Authenticated: true}, - {Enabled: true, Channel: "matches"}, - }, + Subscriptions: defaultSubscriptions.Clone(), } c.Requester, err = request.New(c.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), @@ -150,6 +143,9 @@ func (c *CoinbasePro) Setup(exch *config.Exchange) error { if err != nil { return err } + + c.checkSubscriptions() + wsRunningURL, err := c.API.Endpoints.GetURL(exchange.WebsocketSpot) if err != nil { return err diff --git a/exchanges/subscription/subscription.go b/exchanges/subscription/subscription.go index 516466363eb..64db2c3e415 100644 --- a/exchanges/subscription/subscription.go +++ b/exchanges/subscription/subscription.go @@ -31,6 +31,8 @@ const ( AllTradesChannel = "allTrades" MyTradesChannel = "myTrades" MyOrdersChannel = "myOrders" + MyAccountChannel = "account" + HeartbeatChannel = "heartbeat" ) // Public errors @@ -40,6 +42,7 @@ var ( ErrInStateAlready = errors.New("subscription already in state") ErrInvalidState = errors.New("invalid subscription state") ErrDuplicate = errors.New("duplicate subscription") + ErrNotSupported = errors.New("subscription channel not supported") ) // State tracks the status of a subscription channel