From 8e67bae0f06c16d01197b989b1e1787f641345d5 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 19 Aug 2024 12:26:18 +0700 Subject: [PATCH] BTCMarkets: Add subscription conf --- exchanges/btcmarkets/btcmarkets_test.go | 32 +++++++ exchanges/btcmarkets/btcmarkets_websocket.go | 89 ++++++++++++-------- exchanges/btcmarkets/btcmarkets_wrapper.go | 3 +- exchanges/subscription/subscription.go | 2 + internal/testing/exchange/exchange.go | 8 +- 5 files changed, 93 insertions(+), 41 deletions(-) diff --git a/exchanges/btcmarkets/btcmarkets_test.go b/exchanges/btcmarkets/btcmarkets_test.go index 41054b871e9..dc3b1204da6 100644 --- a/exchanges/btcmarkets/btcmarkets_test.go +++ b/exchanges/btcmarkets/btcmarkets_test.go @@ -19,7 +19,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" + testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" ) var b = &BTCMarkets{} @@ -1134,3 +1136,33 @@ func TestGetCurrencyTradeURL(t *testing.T) { assert.NotEmpty(t, resp) } } + +func TestGenerateSubscriptions(t *testing.T) { + t.Parallel() + b := new(BTCMarkets) + require.NoError(t, testexch.Setup(b), "Test instance Setup must not error") + p := currency.Pairs{currency.NewPairWithDelimiter("BTC", "USD", "_"), currency.NewPairWithDelimiter("ETH", "BTC", "_")} + require.NoError(t, b.CurrencyPairs.StorePairs(asset.Spot, p, false)) + require.NoError(t, b.CurrencyPairs.StorePairs(asset.Spot, p, true)) + b.Websocket.SetCanUseAuthenticatedEndpoints(true) + require.True(t, b.Websocket.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints must return true") + subs, err := b.generateSubscriptions() + require.NoError(t, err, "generateSubscriptions should not error") + pairs, err := b.GetEnabledPairs(asset.Spot) + require.NoError(t, err, "GetEnabledPairs must not error") + exp := subscription.List{} + for _, baseSub := range b.Features.Subscriptions { + s := baseSub.Clone() + if !s.Authenticated && s.Channel != subscription.HeartbeatChannel { + s.Pairs = pairs + } + s.QualifiedChannel = channelName(s) + exp = append(exp, s) + } + testsubs.EqualLists(t, exp, subs) + assert.PanicsWithError(t, + "subscription channel not supported: wibble", + func() { channelName(&subscription.Subscription{Channel: "wibble"}) }, + "should panic on invalid channel", + ) +} diff --git a/exchanges/btcmarkets/btcmarkets_websocket.go b/exchanges/btcmarkets/btcmarkets_websocket.go index 333f77f06ba..742d7f4b711 100644 --- a/exchanges/btcmarkets/btcmarkets_websocket.go +++ b/exchanges/btcmarkets/btcmarkets_websocket.go @@ -9,6 +9,7 @@ import ( "net/http" "strconv" "strings" + "text/template" "time" "github.com/gorilla/websocket" @@ -33,10 +34,26 @@ const ( var ( errTypeAssertionFailure = errors.New("type assertion failure") errChecksumFailure = errors.New("crc32 checksum failure") - - authChannels = []string{fundChange, heartbeat, orderChange} ) +var defaultSubscriptions = subscription.List{ + {Enabled: true, Asset: asset.Spot, Channel: subscription.TickerChannel}, + {Enabled: true, Asset: asset.Spot, Channel: subscription.OrderbookChannel}, + {Enabled: true, Asset: asset.Spot, Channel: subscription.AllTradesChannel}, + {Enabled: true, Channel: subscription.MyOrdersChannel, Authenticated: true}, + {Enabled: true, Channel: subscription.MyAccountChannel, Authenticated: true}, + {Enabled: true, Channel: subscription.HeartbeatChannel}, +} + +var subscriptionNames = map[string]string{ + subscription.OrderbookChannel: wsOrderbookUpdate, + subscription.TickerChannel: tick, + subscription.AllTradesChannel: tradeEndPoint, + subscription.MyOrdersChannel: orderChange, + subscription.MyAccountChannel: fundChange, + subscription.HeartbeatChannel: heartbeat, +} + // WsConnect connects to a websocket feed func (b *BTCMarkets) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { @@ -326,29 +343,13 @@ func (b *BTCMarkets) wsHandleData(respRaw []byte) error { return nil } -func (b *BTCMarkets) generateDefaultSubscriptions() (subscription.List, error) { - var channels = []string{wsOrderbookUpdate, tick, tradeEndPoint} - enabledCurrencies, err := b.GetEnabledPairs(asset.Spot) - if err != nil { - return nil, err - } - var subscriptions subscription.List - for i := range channels { - subscriptions = append(subscriptions, &subscription.Subscription{ - Channel: channels[i], - Pairs: enabledCurrencies, - Asset: asset.Spot, - }) - } +func (b *BTCMarkets) generateSubscriptions() (subscription.List, error) { + return b.Features.Subscriptions.ExpandTemplates(b) +} - if b.Websocket.CanUseAuthenticatedEndpoints() { - for i := range authChannels { - subscriptions = append(subscriptions, &subscription.Subscription{ - Channel: authChannels[i], - }) - } - } - return subscriptions, nil +// GetSubscriptionTemplate returns a subscription channel template +func (b *BTCMarkets) GetSubscriptionTemplate(_ *subscription.Subscription) (*template.Template, error) { + return template.New("master.tmpl").Funcs(template.FuncMap{"channelName": channelName}).Parse(subTplText) } // Subscribe sends a websocket message to receive data from the channel @@ -358,13 +359,17 @@ func (b *BTCMarkets) Subscribe(subs subscription.List) error { } var errs error - for _, s := range subs { - if baseReq.Key == "" && common.StringSliceContains(authChannels, s.Channel) { - if err := b.authWsSubscibeReq(baseReq); err != nil { - return err + if authed := subs.Private(); len(authed) > 0 { + if err := b.signWsReq(baseReq); err != nil { + errs = err + for _, s := range authed { + errs = common.AppendError(errs, fmt.Errorf("%w: %s", request.ErrAuthRequestFailed, s)) } + subs = subs.Public() } + } + for _, batch := range subs.GroupByPairs() { if baseReq.MessageType == subscribe && len(b.Websocket.GetSubscriptions()) != 0 { baseReq.MessageType = addSubscription // After first *successful* subscription API requires addSubscription baseReq.ClientType = clientType // Note: Only addSubscription requires/accepts clientType @@ -372,12 +377,15 @@ func (b *BTCMarkets) Subscribe(subs subscription.List) error { r := baseReq - r.Channels = []string{s.Channel} - r.MarketIDs = s.Pairs.Strings() + r.MarketIDs = batch[0].Pairs.Strings() + r.Channels = make([]string, len(batch)) + for i, s := range batch { + r.Channels[i] = s.QualifiedChannel + } err := b.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, r) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, s) + err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, batch...) } if err != nil { errs = common.AppendError(errs, err) @@ -387,7 +395,7 @@ func (b *BTCMarkets) Subscribe(subs subscription.List) error { return errs } -func (b *BTCMarkets) authWsSubscibeReq(r *WsSubscribe) error { +func (b *BTCMarkets) signWsReq(r *WsSubscribe) error { creds, err := b.GetCredentials(context.TODO()) if err != nil { return err @@ -471,11 +479,24 @@ func concat(liquidity orderbook.Tranches) string { return c } -// trim turns value into string, removes the decimal point and all the leading -// zeros. +// trim turns value into string, removes the decimal point and all the leading zeros func trim(value float64) string { valstr := strconv.FormatFloat(value, 'f', -1, 64) valstr = strings.ReplaceAll(valstr, ".", "") valstr = strings.TrimLeft(valstr, "0") return valstr } + +func channelName(s *subscription.Subscription) string { + if n, ok := subscriptionNames[s.Channel]; ok { + return n + } + panic(fmt.Errorf("%w: %s", subscription.ErrNotSupported, s.Channel)) +} + +const subTplText = ` +{{ range $asset, $pairs := $.AssetPairs }} + {{- channelName $.S -}} + {{ $.AssetSeparator }} +{{- end }} +` diff --git a/exchanges/btcmarkets/btcmarkets_wrapper.go b/exchanges/btcmarkets/btcmarkets_wrapper.go index 0e97b9e7f53..e8abb596a0a 100644 --- a/exchanges/btcmarkets/btcmarkets_wrapper.go +++ b/exchanges/btcmarkets/btcmarkets_wrapper.go @@ -111,6 +111,7 @@ func (b *BTCMarkets) SetDefaults() { GlobalResultLimit: 1000, }, }, + Subscriptions: defaultSubscriptions.Clone(), } b.Requester, err = request.New(b.Name, @@ -160,7 +161,7 @@ func (b *BTCMarkets) Setup(exch *config.Exchange) error { Connector: b.WsConnect, Subscriber: b.Subscribe, Unsubscriber: b.Unsubscribe, - GenerateSubscriptions: b.generateDefaultSubscriptions, + GenerateSubscriptions: b.generateSubscriptions, Features: &b.Features.Supports.WebsocketCapabilities, OrderbookBufferConfig: buffer.Config{ SortBuffer: true, diff --git a/exchanges/subscription/subscription.go b/exchanges/subscription/subscription.go index db76b370521..6b2cfc24ae1 100644 --- a/exchanges/subscription/subscription.go +++ b/exchanges/subscription/subscription.go @@ -33,6 +33,7 @@ const ( MyOrdersChannel = "myOrders" MyWalletChannel = "myWallet" MyAccountChannel = "myAccount" + HeartbeatChannel = "heartbeat" ) // Public errors @@ -44,6 +45,7 @@ var ( ErrInvalidState = errors.New("invalid subscription state") ErrDuplicate = errors.New("duplicate subscription") ErrUseConstChannelName = errors.New("must use standard channel name constants") + ErrNotSupported = errors.New("subscription channel not supported") ) // State tracks the status of a subscription channel diff --git a/internal/testing/exchange/exchange.go b/internal/testing/exchange/exchange.go index fe82b4f5435..3612edb2acf 100644 --- a/internal/testing/exchange/exchange.go +++ b/internal/testing/exchange/exchange.go @@ -3,7 +3,6 @@ package exchange import ( "bufio" "context" - "errors" "fmt" "log" "net/http" @@ -38,11 +37,8 @@ func Setup(e exchange.IBotExchange) error { if err != nil { return fmt.Errorf("LoadConfig() error: %w", err) } - parts := strings.Split(fmt.Sprintf("%T", e), ".") - if len(parts) != 2 { - return errors.New("unexpected parts splitting exchange type name") - } - eName := parts[1] + e.SetDefaults() + eName := e.GetName() exchConf, err := cfg.GetExchangeConfig(eName) if err != nil { return fmt.Errorf("GetExchangeConfig(`%s`) error: %w", eName, err)