Skip to content

Commit

Permalink
BTCMarkets: Add subscription conf
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Aug 20, 2024
1 parent ec283f2 commit 8fd7801
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 40 deletions.
32 changes: 32 additions & 0 deletions exchanges/btcmarkets/btcmarkets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ import (
"github.com/thrasher-corp/gocryptotrader/exchanges/order"
"github.com/thrasher-corp/gocryptotrader/exchanges/orderbook"
"github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange"
testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions"
)

var b = &BTCMarkets{}
Expand Down Expand Up @@ -1134,3 +1136,33 @@ func TestGetCurrencyTradeURL(t *testing.T) {
assert.NotEmpty(t, resp)
}
}

func TestGenerateSubscriptions(t *testing.T) {
t.Parallel()
b := new(BTCMarkets)
require.NoError(t, testexch.Setup(b), "Test instance Setup must not error")
p := currency.Pairs{currency.NewPairWithDelimiter("BTC", "USD", "_"), currency.NewPairWithDelimiter("ETH", "BTC", "_")}
require.NoError(t, b.CurrencyPairs.StorePairs(asset.Spot, p, false))
require.NoError(t, b.CurrencyPairs.StorePairs(asset.Spot, p, true))
b.Websocket.SetCanUseAuthenticatedEndpoints(true)
require.True(t, b.Websocket.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints must return true")
subs, err := b.generateSubscriptions()
require.NoError(t, err, "generateSubscriptions should not error")
pairs, err := b.GetEnabledPairs(asset.Spot)
require.NoError(t, err, "GetEnabledPairs must not error")
exp := subscription.List{}
for _, baseSub := range b.Features.Subscriptions {
s := baseSub.Clone()
if !s.Authenticated && s.Channel != subscription.HeartbeatChannel {
s.Pairs = pairs
}
s.QualifiedChannel = channelName(s)
exp = append(exp, s)
}
testsubs.EqualLists(t, exp, subs)
assert.PanicsWithError(t,
"subscription channel not supported: wibble",
func() { channelName(&subscription.Subscription{Channel: "wibble"}) },
"should panic on invalid channel",
)
}
95 changes: 61 additions & 34 deletions exchanges/btcmarkets/btcmarkets_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"strconv"
"strings"
"text/template"
"time"

"github.com/gorilla/websocket"
Expand All @@ -18,6 +19,7 @@ import (
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/order"
"github.com/thrasher-corp/gocryptotrader/exchanges/orderbook"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
"github.com/thrasher-corp/gocryptotrader/exchanges/ticker"
Expand All @@ -32,10 +34,26 @@ const (
var (
errTypeAssertionFailure = errors.New("type assertion failure")
errChecksumFailure = errors.New("crc32 checksum failure")

authChannels = []string{fundChange, heartbeat, orderChange}
)

var defaultSubscriptions = subscription.List{
{Enabled: true, Asset: asset.Spot, Channel: subscription.TickerChannel},
{Enabled: true, Asset: asset.Spot, Channel: subscription.OrderbookChannel},
{Enabled: true, Asset: asset.Spot, Channel: subscription.AllTradesChannel},
{Enabled: true, Channel: subscription.MyOrdersChannel, Authenticated: true},
{Enabled: true, Channel: subscription.MyAccountChannel, Authenticated: true},
{Enabled: true, Channel: subscription.HeartbeatChannel},
}

var subscriptionNames = map[string]string{
subscription.OrderbookChannel: wsOrderbookUpdate,
subscription.TickerChannel: tick,
subscription.AllTradesChannel: tradeEndPoint,
subscription.MyOrdersChannel: orderChange,
subscription.MyAccountChannel: fundChange,
subscription.HeartbeatChannel: heartbeat,
}

// WsConnect connects to a websocket feed
func (b *BTCMarkets) WsConnect() error {
if !b.Websocket.IsEnabled() || !b.IsEnabled() {
Expand Down Expand Up @@ -325,29 +343,13 @@ func (b *BTCMarkets) wsHandleData(respRaw []byte) error {
return nil
}

func (b *BTCMarkets) generateDefaultSubscriptions() (subscription.List, error) {
var channels = []string{wsOrderbookUpdate, tick, tradeEndPoint}
enabledCurrencies, err := b.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
}
var subscriptions subscription.List
for i := range channels {
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channels[i],
Pairs: enabledCurrencies,
Asset: asset.Spot,
})
}
func (b *BTCMarkets) generateSubscriptions() (subscription.List, error) {
return b.Features.Subscriptions.ExpandTemplates(b)
}

if b.Websocket.CanUseAuthenticatedEndpoints() {
for i := range authChannels {
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: authChannels[i],
})
}
}
return subscriptions, nil
// GetSubscriptionTemplate returns a subscription channel template
func (b *BTCMarkets) GetSubscriptionTemplate(_ *subscription.Subscription) (*template.Template, error) {
return template.New("master.tmpl").Funcs(template.FuncMap{"channelName": channelName}).Parse(subTplText)
}

// Subscribe sends a websocket message to receive data from the channel
Expand All @@ -357,26 +359,38 @@ func (b *BTCMarkets) Subscribe(subs subscription.List) error {
}

var errs error
for _, s := range subs {
if baseReq.Key == "" && common.StringDataContains(authChannels, s.Channel) {
if err := b.authWsSubscibeReq(baseReq); err != nil {
return err
if authed := subs.Authenticated(); len(authed) > 0 {
if err := b.signWsReq(baseReq); err != nil {
errs = err
n := subscription.List{}
for _, s := range subs {
if s.Authenticated {
errs = common.AppendError(errs, fmt.Errorf("%w: %s", request.ErrAuthRequestFailed, s))
} else {
n = append(n, s)
}
}
subs = n
}
}

for _, batch := range subs.GroupByPairs() {
if baseReq.MessageType == subscribe && len(b.Websocket.GetSubscriptions()) != 0 {
baseReq.MessageType = addSubscription // After first *successful* subscription API requires addSubscription
baseReq.ClientType = clientType // Note: Only addSubscription requires/accepts clientType
}

r := baseReq

r.Channels = []string{s.Channel}
r.MarketIDs = s.Pairs.Strings()
r.MarketIDs = batch[0].Pairs.Strings()
r.Channels = make([]string, len(batch))
for i, s := range batch {
r.Channels[i] = s.QualifiedChannel
}

err := b.Websocket.Conn.SendJSONMessage(r)
if err == nil {
err = b.Websocket.AddSuccessfulSubscriptions(s)
err = b.Websocket.AddSuccessfulSubscriptions(batch...)
}
if err != nil {
errs = common.AppendError(errs, err)
Expand All @@ -386,7 +400,7 @@ func (b *BTCMarkets) Subscribe(subs subscription.List) error {
return errs
}

func (b *BTCMarkets) authWsSubscibeReq(r *WsSubscribe) error {
func (b *BTCMarkets) signWsReq(r *WsSubscribe) error {
creds, err := b.GetCredentials(context.TODO())
if err != nil {
return err
Expand Down Expand Up @@ -470,11 +484,24 @@ func concat(liquidity orderbook.Tranches) string {
return c
}

// trim turns value into string, removes the decimal point and all the leading
// zeros.
// trim turns value into string, removes the decimal point and all the leading zeros
func trim(value float64) string {
valstr := strconv.FormatFloat(value, 'f', -1, 64)
valstr = strings.ReplaceAll(valstr, ".", "")
valstr = strings.TrimLeft(valstr, "0")
return valstr
}

func channelName(s *subscription.Subscription) string {
if n, ok := subscriptionNames[s.Channel]; ok {
return n
}
panic(fmt.Errorf("%w: %s", subscription.ErrNotSupported, s.Channel))
}

const subTplText = `
{{ range $asset, $pairs := $.AssetPairs }}
{{- channelName $.S -}}
{{ $.AssetSeparator }}
{{- end }}
`
3 changes: 3 additions & 0 deletions exchanges/subscription/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ const (
AllTradesChannel = "allTrades"
MyTradesChannel = "myTrades"
MyOrdersChannel = "myOrders"
MyAccountChannel = "account"
HeartbeatChannel = "heartbeat"
)

// Public errors
Expand All @@ -40,6 +42,7 @@ var (
ErrInStateAlready = errors.New("subscription already in state")
ErrInvalidState = errors.New("invalid subscription state")
ErrDuplicate = errors.New("duplicate subscription")
ErrNotSupported = errors.New("subscription channel not supported")
)

// State tracks the status of a subscription channel
Expand Down
8 changes: 2 additions & 6 deletions internal/testing/exchange/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package exchange
import (
"bufio"
"context"
"errors"
"fmt"
"log"
"net/http"
Expand Down Expand Up @@ -39,11 +38,8 @@ func Setup(e exchange.IBotExchange) error {
if err != nil {
return fmt.Errorf("LoadConfig() error: %w", err)
}
parts := strings.Split(fmt.Sprintf("%T", e), ".")
if len(parts) != 2 {
return errors.New("unexpected parts splitting exchange type name")
}
eName := parts[1]
e.SetDefaults()
eName := e.GetName()
exchConf, err := cfg.GetExchangeConfig(eName)
if err != nil {
return fmt.Errorf("GetExchangeConfig(`%s`) error: %w", eName, err)
Expand Down

0 comments on commit 8fd7801

Please sign in to comment.