Skip to content

Commit

Permalink
Bitstamp: Add subscription configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Dec 4, 2024
1 parent b98e82d commit 7c9b678
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 64 deletions.
27 changes: 27 additions & 0 deletions exchanges/bitstamp/bitstamp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ import (
"github.com/thrasher-corp/gocryptotrader/exchanges/order"
"github.com/thrasher-corp/gocryptotrader/exchanges/orderbook"
"github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange"
testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions"
"github.com/thrasher-corp/gocryptotrader/portfolio/banking"
"github.com/thrasher-corp/gocryptotrader/portfolio/withdraw"
)
Expand Down Expand Up @@ -1035,3 +1037,28 @@ func TestGetCurrencyTradeURL(t *testing.T) {
assert.NotEmpty(t, resp)
}
}

func TestGenerateSubscriptions(t *testing.T) {
t.Parallel()
b.Websocket.SetCanUseAuthenticatedEndpoints(true)
require.True(t, b.Websocket.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints must return true")
subs, err := b.generateSubscriptions()
require.NoError(t, err, "generateSubscriptions should not error")
exp := subscription.List{}
pairs, err := b.GetEnabledPairs(asset.Spot)
require.NoError(t, err, "GetEnabledPairs must not error")
for _, baseSub := range b.Features.Subscriptions {
for _, p := range pairs.Format(currency.PairFormat{Uppercase: false}) {
s := baseSub.Clone()
s.Pairs = currency.Pairs{p}
s.QualifiedChannel = channelName(s) + "_" + p.String()
exp = append(exp, s)
}
}
testsubs.EqualLists(t, exp, subs)
assert.PanicsWithError(t,
"subscription channel not supported: wibble",
func() { channelName(&subscription.Subscription{Channel: "wibble"}) },
"should panic on invalid channel",
)
}
131 changes: 68 additions & 63 deletions exchanges/bitstamp/bitstamp_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ import (
"net/http"
"strconv"
"strings"
"text/template"
"time"

"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/currency"
exchange "github.com/thrasher-corp/gocryptotrader/exchanges"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/kline"
"github.com/thrasher-corp/gocryptotrader/exchanges/order"
"github.com/thrasher-corp/gocryptotrader/exchanges/orderbook"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
Expand All @@ -31,17 +33,21 @@ const (

var (
hbMsg = []byte(`{"event":"bts:heartbeat"}`)
)

defaultSubChannels = []string{
bitstampAPIWSTrades,
bitstampAPIWSOrderbook,
}
var defaultSubscriptions = subscription.List{
{Enabled: true, Asset: asset.Spot, Channel: subscription.OrderbookChannel, Interval: kline.HundredMilliseconds},
{Enabled: true, Asset: asset.Spot, Channel: subscription.AllTradesChannel},
{Enabled: true, Asset: asset.Spot, Channel: subscription.MyOrdersChannel, Authenticated: true},
{Enabled: true, Asset: asset.Spot, Channel: subscription.MyTradesChannel, Authenticated: true},
}

defaultAuthSubChannels = []string{
bitstampAPIWSMyOrders,
bitstampAPIWSMyTrades,
}
)
var subscriptionNames = map[string]string{
subscription.OrderbookChannel: bitstampAPIWSOrderbook,
subscription.AllTradesChannel: bitstampAPIWSTrades,
subscription.MyOrdersChannel: bitstampAPIWSMyOrders,
subscription.MyTradesChannel: bitstampAPIWSMyTrades,
}

// WsConnect connects to a websocket feed
func (b *Bitstamp) WsConnect() error {
Expand Down Expand Up @@ -232,66 +238,40 @@ func (b *Bitstamp) handleWSOrder(wsResp *websocketResponse, msg []byte) error {
return nil
}

func (b *Bitstamp) generateDefaultSubscriptions() (subscription.List, error) {
enabledCurrencies, err := b.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
}
var subscriptions subscription.List
for i := range enabledCurrencies {
p, err := b.FormatExchangeCurrency(enabledCurrencies[i], asset.Spot)
if err != nil {
return nil, err
}
for j := range defaultSubChannels {
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: defaultSubChannels[j] + "_" + p.String(),
Asset: asset.Spot,
Pairs: currency.Pairs{p},
})
}
if b.Websocket.CanUseAuthenticatedEndpoints() {
for j := range defaultAuthSubChannels {
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: defaultAuthSubChannels[j] + "_" + p.String(),
Asset: asset.Spot,
Pairs: currency.Pairs{p},
Params: map[string]interface{}{
"auth": struct{}{},
},
})
}
}
}
return subscriptions, nil
func (b *Bitstamp) generateSubscriptions() (subscription.List, error) {
return b.Features.Subscriptions.ExpandTemplates(b)
}

// Subscribe sends a websocket message to receive data from the channel
func (b *Bitstamp) Subscribe(channelsToSubscribe subscription.List) error {
var errs error
var auth *WebsocketAuthResponse
// GetSubscriptionTemplate returns a subscription channel template
func (b *Bitstamp) GetSubscriptionTemplate(_ *subscription.Subscription) (*template.Template, error) {
return template.New("master.tmpl").Funcs(template.FuncMap{"channelName": channelName}).Parse(subTplText)
}

for i := range channelsToSubscribe {
if _, ok := channelsToSubscribe[i].Params["auth"]; ok {
var err error
auth, err = b.FetchWSAuth(context.TODO())
if err != nil {
errs = common.AppendError(errs, err)
}
break
}
// Subscribe sends a websocket message to receive data from a list of channels
func (b *Bitstamp) Subscribe(subs subscription.List) error {
var errs error
var creds *WebsocketAuthResponse
if authed := subs.Private(); len(authed) > 0 {
creds, errs = b.FetchWSAuth(context.TODO())
}
return common.AppendError(errs, b.ParallelChanOp(subs, func(s subscription.List) error { return b.subscribe(s, creds) }, 1))
}

for _, s := range channelsToSubscribe {
func (b *Bitstamp) subscribe(subs subscription.List, creds *WebsocketAuthResponse) error {
subs, errs := subs.ExpandTemplates(b)
for _, s := range subs {
req := websocketEventRequest{
Event: "bts:subscribe",
Data: websocketData{
Channel: s.Channel,
Channel: s.QualifiedChannel,
},
}
if _, ok := s.Params["auth"]; ok && auth != nil {
req.Data.Channel = "private-" + req.Data.Channel + "-" + strconv.Itoa(int(auth.UserID))
req.Data.Auth = auth.Token
if s.Authenticated {
if creds == nil {
return request.ErrAuthRequestFailed
}
req.Data.Channel = "private-" + req.Data.Channel + "-" + strconv.Itoa(int(creds.UserID))
req.Data.Auth = creds.Token
}
err := b.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, req)
if err == nil {
Expand All @@ -305,14 +285,18 @@ func (b *Bitstamp) Subscribe(channelsToSubscribe subscription.List) error {
return errs
}

// Unsubscribe sends a websocket message to stop receiving data from the channel
func (b *Bitstamp) Unsubscribe(channelsToUnsubscribe subscription.List) error {
// Unsubscribe sends a websocket message to stop receiving data from a list of channels
func (b *Bitstamp) Unsubscribe(subs subscription.List) error {
return b.ParallelChanOp(subs, b.unsubscribe, 1)
}

func (b *Bitstamp) unsubscribe(subs subscription.List) error {
var errs error
for _, s := range channelsToUnsubscribe {
for _, s := range subs {
req := websocketEventRequest{
Event: "bts:unsubscribe",
Data: websocketData{
Channel: s.Channel,
Channel: s.QualifiedChannel,
},
}
err := b.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, req)
Expand Down Expand Up @@ -459,3 +443,24 @@ func (b *Bitstamp) parseChannelName(r *websocketResponse) error {

return err
}

// channelName converts global channel Names to exchange specific ones
// panics if name is not supported, so should be called within a recover chain
func channelName(s *subscription.Subscription) string {
if s, ok := subscriptionNames[s.Channel]; ok {
return s
}
panic(fmt.Errorf("%w: %s", subscription.ErrNotSupported, s.Channel))
}

const subTplText = `
{{ range $asset, $pairs := $.AssetPairs }}
{{- with $name := channelName $.S }}
{{- range $p := $pairs -}}
{{- $name -}} _ {{- $p -}}
{{ $.PairSeparator }}
{{- end -}}
{{- end }}
{{ $.AssetSeparator }}
{{- end }}
`
3 changes: 2 additions & 1 deletion exchanges/bitstamp/bitstamp_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ func (b *Bitstamp) SetDefaults() {
GlobalResultLimit: 1000,
},
},
Subscriptions: defaultSubscriptions.Clone(),
}

b.Requester, err = request.New(b.Name,
Expand Down Expand Up @@ -156,7 +157,7 @@ func (b *Bitstamp) Setup(exch *config.Exchange) error {
Connector: b.WsConnect,
Subscriber: b.Subscribe,
Unsubscriber: b.Unsubscribe,
GenerateSubscriptions: b.generateDefaultSubscriptions,
GenerateSubscriptions: b.generateSubscriptions,
Features: &b.Features.Supports.WebsocketCapabilities,
})
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions exchanges/subscription/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ var (
ErrInvalidState = errors.New("invalid subscription state")
ErrDuplicate = errors.New("duplicate subscription")
ErrUseConstChannelName = errors.New("must use standard channel name constants")
ErrNotSupported = errors.New("subscription channel not supported")
)

// State tracks the status of a subscription channel
Expand Down

0 comments on commit 7c9b678

Please sign in to comment.