Skip to content

Commit

Permalink
Bybit: Add subscription configuration for spot (thrasher-corp#1601)
Browse files Browse the repository at this point in the history
* Bybit: Subscription configuration for spot

* Bybit: Enable candles ws sub by default

* Orderbook: Use a RW mutex for depth

* Orderbook: Fix race on depth.VerifyOrderbook

Despite being protected by an ob level mutex, this needed to privatise
and protect the option var.
  • Loading branch information
gbjk authored Nov 5, 2024
1 parent 8fe909d commit 199d30d
Show file tree
Hide file tree
Showing 11 changed files with 252 additions and 164 deletions.
2 changes: 1 addition & 1 deletion exchanges/bybit/bybit_inverse_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (by *Bybit) InverseUnsubscribe(channelSubscriptions subscription.List) erro
}

func (by *Bybit) handleInversePayloadSubscription(operation string, channelSubscriptions subscription.List) error {
payloads, err := by.handleSubscriptions(asset.CoinMarginedFutures, operation, channelSubscriptions)
payloads, err := by.handleSubscriptions(operation, channelSubscriptions)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion exchanges/bybit/bybit_linear_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (by *Bybit) LinearUnsubscribe(channelSubscriptions subscription.List) error
}

func (by *Bybit) handleLinearPayloadSubscription(operation string, channelSubscriptions subscription.List) error {
payloads, err := by.handleSubscriptions(asset.USDTMarginedFutures, operation, channelSubscriptions)
payloads, err := by.handleSubscriptions(operation, channelSubscriptions)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion exchanges/bybit/bybit_options_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (by *Bybit) OptionUnsubscribe(channelSubscriptions subscription.List) error
}

func (by *Bybit) handleOptionsPayloadSubscription(operation string, channelSubscriptions subscription.List) error {
payloads, err := by.handleSubscriptions(asset.Options, operation, channelSubscriptions)
payloads, err := by.handleSubscriptions(operation, channelSubscriptions)
if err != nil {
return err
}
Expand Down
95 changes: 95 additions & 0 deletions exchanges/bybit/bybit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"slices"
"testing"
"time"

"github.com/gofrs/uuid"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/key"
"github.com/thrasher-corp/gocryptotrader/currency"
exchange "github.com/thrasher-corp/gocryptotrader/exchanges"
Expand All @@ -22,8 +25,11 @@ import (
"github.com/thrasher-corp/gocryptotrader/exchanges/order"
"github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
"github.com/thrasher-corp/gocryptotrader/exchanges/ticker"
testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange"
testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions"
testws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket"
"github.com/thrasher-corp/gocryptotrader/portfolio/withdraw"
"github.com/thrasher-corp/gocryptotrader/types"
)
Expand Down Expand Up @@ -3683,3 +3689,92 @@ func TestGetCurrencyTradeURL(t *testing.T) {
assert.NotEmpty(t, resp)
}
}

// TestGenerateSubscriptions exercises generateSubscriptions
func TestGenerateSubscriptions(t *testing.T) {
t.Parallel()

b := new(Bybit)
require.NoError(t, testexch.Setup(b), "Test instance Setup must not error")

b.Websocket.SetCanUseAuthenticatedEndpoints(true)
subs, err := b.generateSubscriptions()
require.NoError(t, err, "generateSubscriptions must not error")
exp := subscription.List{}
for _, s := range b.Features.Subscriptions {
for _, a := range b.GetAssetTypes(true) {
if s.Asset != asset.All && s.Asset != a {
continue
}
pairs, err := b.GetEnabledPairs(a)
require.NoErrorf(t, err, "GetEnabledPairs %s must not error", a)
pairs = common.SortStrings(pairs).Format(currency.PairFormat{Uppercase: true, Delimiter: ""})
s := s.Clone() //nolint:govet // Intentional lexical scope shadow
s.Asset = a
if isSymbolChannel(channelName(s)) {
for i, p := range pairs {
s := s.Clone() //nolint:govet // Intentional lexical scope shadow
switch s.Channel {
case subscription.CandlesChannel:
s.QualifiedChannel = fmt.Sprintf("%s.%.f.%s", channelName(s), s.Interval.Duration().Minutes(), p)
case subscription.OrderbookChannel:
s.QualifiedChannel = fmt.Sprintf("%s.%d.%s", channelName(s), s.Levels, p)
default:
s.QualifiedChannel = channelName(s) + "." + p.String()
}
s.Pairs = pairs[i : i+1]
exp = append(exp, s)
}
} else {
s.Pairs = pairs
s.QualifiedChannel = channelName(s)
exp = append(exp, s)
}
}
}
testsubs.EqualLists(t, exp, subs)
}

func TestSubscribe(t *testing.T) {
t.Parallel()
b := new(Bybit)
require.NoError(t, testexch.Setup(b), "Test instance Setup must not error")
subs, err := b.Features.Subscriptions.ExpandTemplates(b)
require.NoError(t, err, "ExpandTemplates must not error")
b.Features.Subscriptions = subscription.List{}
testexch.SetupWs(t, b)
err = b.Subscribe(subs)
require.NoError(t, err, "Subscribe must not error")
}

func TestAuthSubscribe(t *testing.T) {
t.Parallel()
b := new(Bybit)
require.NoError(t, testexch.Setup(b), "Test instance Setup must not error")
b.Websocket.SetCanUseAuthenticatedEndpoints(true)
subs, err := b.Features.Subscriptions.ExpandTemplates(b)
require.NoError(t, err, "ExpandTemplates must not error")
b.Features.Subscriptions = subscription.List{}
success := true
mock := func(tb testing.TB, msg []byte, w *websocket.Conn) error {
tb.Helper()
var req SubscriptionArgument
require.NoError(tb, json.Unmarshal(msg, &req), "Unmarshal must not error")
require.Equal(tb, "subscribe", req.Operation)
msg, err = json.Marshal(SubscriptionResponse{
Success: success,
RetMsg: "Mock Resp Error",
RequestID: req.RequestID,
Operation: req.Operation,
})
require.NoError(tb, err, "Marshal must not error")
return w.WriteMessage(websocket.TextMessage, msg)
}
b = testexch.MockWsInstance[Bybit](t, testws.CurryWsMockUpgrader(t, mock))
b.Websocket.AuthConn = b.Websocket.Conn
err = b.Subscribe(subs)
require.NoError(t, err, "Subscribe must not error")
success = false
err = b.Subscribe(subs)
assert.ErrorContains(t, err, "Mock Resp Error", "Subscribe should error containing the returned RetMsg")
}
204 changes: 97 additions & 107 deletions exchanges/bybit/bybit_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"net/http"
"strconv"
"strings"
"text/template"
"time"

"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/crypto"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/account"
Expand Down Expand Up @@ -55,6 +57,27 @@ const (
websocketPrivate = "wss://stream.bybit.com/v5/private"
)

var defaultSubscriptions = subscription.List{
{Enabled: true, Asset: asset.Spot, Channel: subscription.TickerChannel},
{Enabled: true, Asset: asset.Spot, Channel: subscription.OrderbookChannel, Levels: 50},
{Enabled: true, Asset: asset.Spot, Channel: subscription.AllTradesChannel},
{Enabled: true, Asset: asset.Spot, Channel: subscription.CandlesChannel, Interval: kline.OneHour},
{Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: subscription.MyOrdersChannel},
{Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: subscription.MyWalletChannel},
{Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: subscription.MyTradesChannel},
{Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: chanPositions},
}

var subscriptionNames = map[string]string{
subscription.TickerChannel: chanPublicTicker,
subscription.OrderbookChannel: chanOrderbook,
subscription.AllTradesChannel: chanPublicTrade,
subscription.MyOrdersChannel: chanOrder,
subscription.MyTradesChannel: chanExecution,
subscription.MyWalletChannel: chanWallet,
subscription.CandlesChannel: chanKline,
}

// WsConnect connects to a websocket feed
func (by *Bybit) WsConnect() error {
if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Spot) {
Expand Down Expand Up @@ -139,73 +162,36 @@ func (by *Bybit) Subscribe(channelsToSubscribe subscription.List) error {
return by.handleSpotSubscription("subscribe", channelsToSubscribe)
}

func (by *Bybit) handleSubscriptions(assetType asset.Item, operation string, channelsToSubscribe subscription.List) ([]SubscriptionArgument, error) {
var args []SubscriptionArgument
arg := SubscriptionArgument{
Operation: operation,
RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10),
Arguments: []string{},
}
authArg := SubscriptionArgument{
auth: true,
Operation: operation,
RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10),
Arguments: []string{},
}

var selectedChannels, positions, execution, order, wallet, greeks, dCP = 0, 1, 2, 3, 4, 5, 6
chanMap := map[string]int{
chanPositions: positions,
chanExecution: execution,
chanOrder: order,
chanWallet: wallet,
chanGreeks: greeks,
chanDCP: dCP}

pairFormat, err := by.GetPairFormat(assetType, true)
func (by *Bybit) handleSubscriptions(operation string, subs subscription.List) (args []SubscriptionArgument, err error) {
subs, err = subs.ExpandTemplates(by)
if err != nil {
return nil, err
}
for i := range channelsToSubscribe {
if len(channelsToSubscribe[i].Pairs) != 1 {
return nil, subscription.ErrNotSinglePair
}
pair := channelsToSubscribe[i].Pairs[0]
switch channelsToSubscribe[i].Channel {
case chanOrderbook:
arg.Arguments = append(arg.Arguments, fmt.Sprintf("%s.%d.%s", channelsToSubscribe[i].Channel, 50, pair.Format(pairFormat).String()))
case chanPublicTrade, chanPublicTicker, chanLiquidation, chanLeverageTokenTicker, chanLeverageTokenNav:
arg.Arguments = append(arg.Arguments, channelsToSubscribe[i].Channel+"."+pair.Format(pairFormat).String())
case chanKline, chanLeverageTokenKline:
interval, err := intervalToString(kline.FiveMin)
if err != nil {
return nil, err
}
arg.Arguments = append(arg.Arguments, channelsToSubscribe[i].Channel+"."+interval+"."+pair.Format(pairFormat).String())
case chanPositions, chanExecution, chanOrder, chanWallet, chanGreeks, chanDCP:
if chanMap[channelsToSubscribe[i].Channel]&selectedChannels > 0 {
continue
}
authArg.Arguments = append(authArg.Arguments, channelsToSubscribe[i].Channel)
// adding the channel to selected channels so that we will not visit it again.
selectedChannels |= chanMap[channelsToSubscribe[i].Channel]
}
if len(arg.Arguments) >= 10 {
args = append(args, arg)
arg = SubscriptionArgument{
Operation: operation,
RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10),
Arguments: []string{},
}
return
}
chans := []string{}
authChans := []string{}
for _, s := range subs {
if s.Authenticated {
authChans = append(authChans, s.QualifiedChannel)
} else {
chans = append(chans, s.QualifiedChannel)
}
}
if len(arg.Arguments) != 0 {
args = append(args, arg)
for _, b := range common.Batch(chans, 10) {
args = append(args, SubscriptionArgument{
Operation: operation,
RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10),
Arguments: b,
})
}
if len(authArg.Arguments) != 0 {
args = append(args, authArg)
if len(authChans) != 0 {
args = append(args, SubscriptionArgument{
auth: true,
Operation: operation,
RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10),
Arguments: authChans,
})
}
return args, nil
return
}

// Unsubscribe sends a websocket message to stop receiving data from the channel
Expand All @@ -214,7 +200,7 @@ func (by *Bybit) Unsubscribe(channelsToUnsubscribe subscription.List) error {
}

func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe subscription.List) error {
payloads, err := by.handleSubscriptions(asset.Spot, operation, channelsToSubscribe)
payloads, err := by.handleSubscriptions(operation, channelsToSubscribe)
if err != nil {
return err
}
Expand Down Expand Up @@ -243,50 +229,18 @@ func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe su
return nil
}

// GenerateDefaultSubscriptions generates default subscription
func (by *Bybit) GenerateDefaultSubscriptions() (subscription.List, error) {
var subscriptions subscription.List
var channels = []string{
chanPublicTicker,
chanOrderbook,
chanPublicTrade,
}
if by.Websocket.CanUseAuthenticatedEndpoints() {
channels = append(channels, []string{
chanPositions,
chanExecution,
chanOrder,
chanWallet,
}...)
}
pairs, err := by.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
}
for x := range channels {
switch channels[x] {
case chanPositions,
chanExecution,
chanOrder,
chanDCP,
chanWallet:
subscriptions = append(subscriptions,
&subscription.Subscription{
Channel: channels[x],
Asset: asset.Spot,
})
default:
for z := range pairs {
subscriptions = append(subscriptions,
&subscription.Subscription{
Channel: channels[x],
Pairs: currency.Pairs{pairs[z]},
Asset: asset.Spot,
})
}
}
}
return subscriptions, nil
// generateSubscriptions generates default subscription
func (by *Bybit) generateSubscriptions() (subscription.List, error) {
return by.Features.Subscriptions.ExpandTemplates(by)
}

// GetSubscriptionTemplate returns a subscription channel template
func (by *Bybit) GetSubscriptionTemplate(_ *subscription.Subscription) (*template.Template, error) {
return template.New("master.tmpl").Funcs(template.FuncMap{
"channelName": channelName,
"isSymbolChannel": isSymbolChannel,
"intervalToString": intervalToString,
}).Parse(subTplText)
}

// wsReadData receives and passes on websocket messages for processing
Expand Down Expand Up @@ -788,3 +742,39 @@ func (by *Bybit) wsProcessOrderbook(assetType asset.Item, resp *WebsocketRespons
}
return nil
}

// channelName converts global channel names to exchange specific names
func channelName(s *subscription.Subscription) string {
if name, ok := subscriptionNames[s.Channel]; ok {
return name
}
return s.Channel
}

// isSymbolChannel returns whether the channel accepts a symbol parameter
func isSymbolChannel(name string) bool {
switch name {
case chanPositions, chanExecution, chanOrder, chanDCP, chanWallet:
return false
}
return true
}

const subTplText = `
{{ with $name := channelName $.S }}
{{- range $asset, $pairs := $.AssetPairs }}
{{- if isSymbolChannel $name }}
{{- range $p := $pairs }}
{{- $name -}} .
{{- if eq $name "orderbook" -}} {{- $.S.Levels -}} . {{- end }}
{{- if eq $name "kline" -}} {{- intervalToString $.S.Interval -}} . {{- end }}
{{- $p }}
{{- $.PairSeparator }}
{{- end }}
{{- else }}
{{- $name }}
{{- end }}
{{- end }}
{{- $.AssetSeparator }}
{{- end }}
`
Loading

0 comments on commit 199d30d

Please sign in to comment.