From 199d30d4420b550303a20c3b88462928587bdab9 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Tue, 5 Nov 2024 06:05:38 +0100 Subject: [PATCH] Bybit: Add subscription configuration for spot (#1601) * Bybit: Subscription configuration for spot * Bybit: Enable candles ws sub by default * Orderbook: Use a RW mutex for depth * Orderbook: Fix race on depth.VerifyOrderbook Despite being protected by an ob level mutex, this needed to privatise and protect the option var. --- exchanges/bybit/bybit_inverse_websocket.go | 2 +- exchanges/bybit/bybit_linear_websocket.go | 2 +- exchanges/bybit/bybit_options_websocket.go | 2 +- exchanges/bybit/bybit_test.go | 95 ++++++++++ exchanges/bybit/bybit_websocket.go | 204 ++++++++++----------- exchanges/bybit/bybit_wrapper.go | 3 +- exchanges/orderbook/depth.go | 88 ++++----- exchanges/orderbook/depth_test.go | 6 +- exchanges/orderbook/orderbook_types.go | 2 +- exchanges/stream/buffer/buffer.go | 11 +- exchanges/subscription/subscription.go | 1 + 11 files changed, 252 insertions(+), 164 deletions(-) diff --git a/exchanges/bybit/bybit_inverse_websocket.go b/exchanges/bybit/bybit_inverse_websocket.go index 1160fc2d8f5..0a0346bde58 100644 --- a/exchanges/bybit/bybit_inverse_websocket.go +++ b/exchanges/bybit/bybit_inverse_websocket.go @@ -66,7 +66,7 @@ func (by *Bybit) InverseUnsubscribe(channelSubscriptions subscription.List) erro } func (by *Bybit) handleInversePayloadSubscription(operation string, channelSubscriptions subscription.List) error { - payloads, err := by.handleSubscriptions(asset.CoinMarginedFutures, operation, channelSubscriptions) + payloads, err := by.handleSubscriptions(operation, channelSubscriptions) if err != nil { return err } diff --git a/exchanges/bybit/bybit_linear_websocket.go b/exchanges/bybit/bybit_linear_websocket.go index 303b2f76427..0edb1a322b3 100644 --- a/exchanges/bybit/bybit_linear_websocket.go +++ b/exchanges/bybit/bybit_linear_websocket.go @@ -84,7 +84,7 @@ func (by *Bybit) LinearUnsubscribe(channelSubscriptions subscription.List) error } func (by *Bybit) handleLinearPayloadSubscription(operation string, channelSubscriptions subscription.List) error { - payloads, err := by.handleSubscriptions(asset.USDTMarginedFutures, operation, channelSubscriptions) + payloads, err := by.handleSubscriptions(operation, channelSubscriptions) if err != nil { return err } diff --git a/exchanges/bybit/bybit_options_websocket.go b/exchanges/bybit/bybit_options_websocket.go index c67a611486c..e2e5836346b 100644 --- a/exchanges/bybit/bybit_options_websocket.go +++ b/exchanges/bybit/bybit_options_websocket.go @@ -73,7 +73,7 @@ func (by *Bybit) OptionUnsubscribe(channelSubscriptions subscription.List) error } func (by *Bybit) handleOptionsPayloadSubscription(operation string, channelSubscriptions subscription.List) error { - payloads, err := by.handleSubscriptions(asset.Options, operation, channelSubscriptions) + payloads, err := by.handleSubscriptions(operation, channelSubscriptions) if err != nil { return err } diff --git a/exchanges/bybit/bybit_test.go b/exchanges/bybit/bybit_test.go index 52362bfcc3f..74bbfac34c0 100644 --- a/exchanges/bybit/bybit_test.go +++ b/exchanges/bybit/bybit_test.go @@ -4,13 +4,16 @@ import ( "context" "encoding/json" "errors" + "fmt" "slices" "testing" "time" "github.com/gofrs/uuid" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/key" "github.com/thrasher-corp/gocryptotrader/currency" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" @@ -22,8 +25,11 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" + testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" + testws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" "github.com/thrasher-corp/gocryptotrader/types" ) @@ -3683,3 +3689,92 @@ func TestGetCurrencyTradeURL(t *testing.T) { assert.NotEmpty(t, resp) } } + +// TestGenerateSubscriptions exercises generateSubscriptions +func TestGenerateSubscriptions(t *testing.T) { + t.Parallel() + + b := new(Bybit) + require.NoError(t, testexch.Setup(b), "Test instance Setup must not error") + + b.Websocket.SetCanUseAuthenticatedEndpoints(true) + subs, err := b.generateSubscriptions() + require.NoError(t, err, "generateSubscriptions must not error") + exp := subscription.List{} + for _, s := range b.Features.Subscriptions { + for _, a := range b.GetAssetTypes(true) { + if s.Asset != asset.All && s.Asset != a { + continue + } + pairs, err := b.GetEnabledPairs(a) + require.NoErrorf(t, err, "GetEnabledPairs %s must not error", a) + pairs = common.SortStrings(pairs).Format(currency.PairFormat{Uppercase: true, Delimiter: ""}) + s := s.Clone() //nolint:govet // Intentional lexical scope shadow + s.Asset = a + if isSymbolChannel(channelName(s)) { + for i, p := range pairs { + s := s.Clone() //nolint:govet // Intentional lexical scope shadow + switch s.Channel { + case subscription.CandlesChannel: + s.QualifiedChannel = fmt.Sprintf("%s.%.f.%s", channelName(s), s.Interval.Duration().Minutes(), p) + case subscription.OrderbookChannel: + s.QualifiedChannel = fmt.Sprintf("%s.%d.%s", channelName(s), s.Levels, p) + default: + s.QualifiedChannel = channelName(s) + "." + p.String() + } + s.Pairs = pairs[i : i+1] + exp = append(exp, s) + } + } else { + s.Pairs = pairs + s.QualifiedChannel = channelName(s) + exp = append(exp, s) + } + } + } + testsubs.EqualLists(t, exp, subs) +} + +func TestSubscribe(t *testing.T) { + t.Parallel() + b := new(Bybit) + require.NoError(t, testexch.Setup(b), "Test instance Setup must not error") + subs, err := b.Features.Subscriptions.ExpandTemplates(b) + require.NoError(t, err, "ExpandTemplates must not error") + b.Features.Subscriptions = subscription.List{} + testexch.SetupWs(t, b) + err = b.Subscribe(subs) + require.NoError(t, err, "Subscribe must not error") +} + +func TestAuthSubscribe(t *testing.T) { + t.Parallel() + b := new(Bybit) + require.NoError(t, testexch.Setup(b), "Test instance Setup must not error") + b.Websocket.SetCanUseAuthenticatedEndpoints(true) + subs, err := b.Features.Subscriptions.ExpandTemplates(b) + require.NoError(t, err, "ExpandTemplates must not error") + b.Features.Subscriptions = subscription.List{} + success := true + mock := func(tb testing.TB, msg []byte, w *websocket.Conn) error { + tb.Helper() + var req SubscriptionArgument + require.NoError(tb, json.Unmarshal(msg, &req), "Unmarshal must not error") + require.Equal(tb, "subscribe", req.Operation) + msg, err = json.Marshal(SubscriptionResponse{ + Success: success, + RetMsg: "Mock Resp Error", + RequestID: req.RequestID, + Operation: req.Operation, + }) + require.NoError(tb, err, "Marshal must not error") + return w.WriteMessage(websocket.TextMessage, msg) + } + b = testexch.MockWsInstance[Bybit](t, testws.CurryWsMockUpgrader(t, mock)) + b.Websocket.AuthConn = b.Websocket.Conn + err = b.Subscribe(subs) + require.NoError(t, err, "Subscribe must not error") + success = false + err = b.Subscribe(subs) + assert.ErrorContains(t, err, "Mock Resp Error", "Subscribe should error containing the returned RetMsg") +} diff --git a/exchanges/bybit/bybit_websocket.go b/exchanges/bybit/bybit_websocket.go index 301a2c6cbe7..6106827c280 100644 --- a/exchanges/bybit/bybit_websocket.go +++ b/exchanges/bybit/bybit_websocket.go @@ -7,9 +7,11 @@ import ( "net/http" "strconv" "strings" + "text/template" "time" "github.com/gorilla/websocket" + "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/account" @@ -55,6 +57,27 @@ const ( websocketPrivate = "wss://stream.bybit.com/v5/private" ) +var defaultSubscriptions = subscription.List{ + {Enabled: true, Asset: asset.Spot, Channel: subscription.TickerChannel}, + {Enabled: true, Asset: asset.Spot, Channel: subscription.OrderbookChannel, Levels: 50}, + {Enabled: true, Asset: asset.Spot, Channel: subscription.AllTradesChannel}, + {Enabled: true, Asset: asset.Spot, Channel: subscription.CandlesChannel, Interval: kline.OneHour}, + {Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: subscription.MyOrdersChannel}, + {Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: subscription.MyWalletChannel}, + {Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: subscription.MyTradesChannel}, + {Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: chanPositions}, +} + +var subscriptionNames = map[string]string{ + subscription.TickerChannel: chanPublicTicker, + subscription.OrderbookChannel: chanOrderbook, + subscription.AllTradesChannel: chanPublicTrade, + subscription.MyOrdersChannel: chanOrder, + subscription.MyTradesChannel: chanExecution, + subscription.MyWalletChannel: chanWallet, + subscription.CandlesChannel: chanKline, +} + // WsConnect connects to a websocket feed func (by *Bybit) WsConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Spot) { @@ -139,73 +162,36 @@ func (by *Bybit) Subscribe(channelsToSubscribe subscription.List) error { return by.handleSpotSubscription("subscribe", channelsToSubscribe) } -func (by *Bybit) handleSubscriptions(assetType asset.Item, operation string, channelsToSubscribe subscription.List) ([]SubscriptionArgument, error) { - var args []SubscriptionArgument - arg := SubscriptionArgument{ - Operation: operation, - RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10), - Arguments: []string{}, - } - authArg := SubscriptionArgument{ - auth: true, - Operation: operation, - RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10), - Arguments: []string{}, - } - - var selectedChannels, positions, execution, order, wallet, greeks, dCP = 0, 1, 2, 3, 4, 5, 6 - chanMap := map[string]int{ - chanPositions: positions, - chanExecution: execution, - chanOrder: order, - chanWallet: wallet, - chanGreeks: greeks, - chanDCP: dCP} - - pairFormat, err := by.GetPairFormat(assetType, true) +func (by *Bybit) handleSubscriptions(operation string, subs subscription.List) (args []SubscriptionArgument, err error) { + subs, err = subs.ExpandTemplates(by) if err != nil { - return nil, err - } - for i := range channelsToSubscribe { - if len(channelsToSubscribe[i].Pairs) != 1 { - return nil, subscription.ErrNotSinglePair - } - pair := channelsToSubscribe[i].Pairs[0] - switch channelsToSubscribe[i].Channel { - case chanOrderbook: - arg.Arguments = append(arg.Arguments, fmt.Sprintf("%s.%d.%s", channelsToSubscribe[i].Channel, 50, pair.Format(pairFormat).String())) - case chanPublicTrade, chanPublicTicker, chanLiquidation, chanLeverageTokenTicker, chanLeverageTokenNav: - arg.Arguments = append(arg.Arguments, channelsToSubscribe[i].Channel+"."+pair.Format(pairFormat).String()) - case chanKline, chanLeverageTokenKline: - interval, err := intervalToString(kline.FiveMin) - if err != nil { - return nil, err - } - arg.Arguments = append(arg.Arguments, channelsToSubscribe[i].Channel+"."+interval+"."+pair.Format(pairFormat).String()) - case chanPositions, chanExecution, chanOrder, chanWallet, chanGreeks, chanDCP: - if chanMap[channelsToSubscribe[i].Channel]&selectedChannels > 0 { - continue - } - authArg.Arguments = append(authArg.Arguments, channelsToSubscribe[i].Channel) - // adding the channel to selected channels so that we will not visit it again. - selectedChannels |= chanMap[channelsToSubscribe[i].Channel] - } - if len(arg.Arguments) >= 10 { - args = append(args, arg) - arg = SubscriptionArgument{ - Operation: operation, - RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10), - Arguments: []string{}, - } + return + } + chans := []string{} + authChans := []string{} + for _, s := range subs { + if s.Authenticated { + authChans = append(authChans, s.QualifiedChannel) + } else { + chans = append(chans, s.QualifiedChannel) } } - if len(arg.Arguments) != 0 { - args = append(args, arg) + for _, b := range common.Batch(chans, 10) { + args = append(args, SubscriptionArgument{ + Operation: operation, + RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10), + Arguments: b, + }) } - if len(authArg.Arguments) != 0 { - args = append(args, authArg) + if len(authChans) != 0 { + args = append(args, SubscriptionArgument{ + auth: true, + Operation: operation, + RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10), + Arguments: authChans, + }) } - return args, nil + return } // Unsubscribe sends a websocket message to stop receiving data from the channel @@ -214,7 +200,7 @@ func (by *Bybit) Unsubscribe(channelsToUnsubscribe subscription.List) error { } func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe subscription.List) error { - payloads, err := by.handleSubscriptions(asset.Spot, operation, channelsToSubscribe) + payloads, err := by.handleSubscriptions(operation, channelsToSubscribe) if err != nil { return err } @@ -243,50 +229,18 @@ func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe su return nil } -// GenerateDefaultSubscriptions generates default subscription -func (by *Bybit) GenerateDefaultSubscriptions() (subscription.List, error) { - var subscriptions subscription.List - var channels = []string{ - chanPublicTicker, - chanOrderbook, - chanPublicTrade, - } - if by.Websocket.CanUseAuthenticatedEndpoints() { - channels = append(channels, []string{ - chanPositions, - chanExecution, - chanOrder, - chanWallet, - }...) - } - pairs, err := by.GetEnabledPairs(asset.Spot) - if err != nil { - return nil, err - } - for x := range channels { - switch channels[x] { - case chanPositions, - chanExecution, - chanOrder, - chanDCP, - chanWallet: - subscriptions = append(subscriptions, - &subscription.Subscription{ - Channel: channels[x], - Asset: asset.Spot, - }) - default: - for z := range pairs { - subscriptions = append(subscriptions, - &subscription.Subscription{ - Channel: channels[x], - Pairs: currency.Pairs{pairs[z]}, - Asset: asset.Spot, - }) - } - } - } - return subscriptions, nil +// generateSubscriptions generates default subscription +func (by *Bybit) generateSubscriptions() (subscription.List, error) { + return by.Features.Subscriptions.ExpandTemplates(by) +} + +// GetSubscriptionTemplate returns a subscription channel template +func (by *Bybit) GetSubscriptionTemplate(_ *subscription.Subscription) (*template.Template, error) { + return template.New("master.tmpl").Funcs(template.FuncMap{ + "channelName": channelName, + "isSymbolChannel": isSymbolChannel, + "intervalToString": intervalToString, + }).Parse(subTplText) } // wsReadData receives and passes on websocket messages for processing @@ -788,3 +742,39 @@ func (by *Bybit) wsProcessOrderbook(assetType asset.Item, resp *WebsocketRespons } return nil } + +// channelName converts global channel names to exchange specific names +func channelName(s *subscription.Subscription) string { + if name, ok := subscriptionNames[s.Channel]; ok { + return name + } + return s.Channel +} + +// isSymbolChannel returns whether the channel accepts a symbol parameter +func isSymbolChannel(name string) bool { + switch name { + case chanPositions, chanExecution, chanOrder, chanDCP, chanWallet: + return false + } + return true +} + +const subTplText = ` +{{ with $name := channelName $.S }} + {{- range $asset, $pairs := $.AssetPairs }} + {{- if isSymbolChannel $name }} + {{- range $p := $pairs }} + {{- $name -}} . + {{- if eq $name "orderbook" -}} {{- $.S.Levels -}} . {{- end }} + {{- if eq $name "kline" -}} {{- intervalToString $.S.Interval -}} . {{- end }} + {{- $p }} + {{- $.PairSeparator }} + {{- end }} + {{- else }} + {{- $name }} + {{- end }} + {{- end }} + {{- $.AssetSeparator }} +{{- end }} +` diff --git a/exchanges/bybit/bybit_wrapper.go b/exchanges/bybit/bybit_wrapper.go index 9e091ba551a..662f1e900b1 100644 --- a/exchanges/bybit/bybit_wrapper.go +++ b/exchanges/bybit/bybit_wrapper.go @@ -183,6 +183,7 @@ func (by *Bybit) SetDefaults() { GlobalResultLimit: 1000, }, }, + Subscriptions: defaultSubscriptions.Clone(), } by.API.Endpoints = by.NewEndpoints() @@ -241,7 +242,7 @@ func (by *Bybit) Setup(exch *config.Exchange) error { Connector: by.WsConnect, Subscriber: by.Subscribe, Unsubscriber: by.Unsubscribe, - GenerateSubscriptions: by.GenerateDefaultSubscriptions, + GenerateSubscriptions: by.generateSubscriptions, Features: &by.Features.Supports.WebsocketCapabilities, OrderbookBufferConfig: buffer.Config{ SortBuffer: true, diff --git a/exchanges/orderbook/depth.go b/exchanges/orderbook/depth.go index a8e9ebda05e..d97b8cea282 100644 --- a/exchanges/orderbook/depth.go +++ b/exchanges/orderbook/depth.go @@ -46,7 +46,7 @@ type Depth struct { // validationError defines current book state and why it was invalidated. validationError error - m sync.Mutex + m sync.RWMutex } // NewDepth returns a new orderbook depth @@ -64,8 +64,8 @@ func (d *Depth) Publish() { // Retrieve returns the orderbook base a copy of the underlying linked list // spread func (d *Depth) Retrieve() (*Base, error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return nil, d.validationError } @@ -81,7 +81,7 @@ func (d *Depth) Retrieve() (*Base, error) { LastUpdateID: d.lastUpdateID, PriceDuplication: d.priceDuplication, IsFundingRate: d.isFundingRate, - VerifyOrderbook: d.VerifyOrderbook, + VerifyOrderbook: d.verifyOrderbook, MaxDepth: d.maxDepth, ChecksumStringRequired: d.checksumStringRequired, }, nil @@ -136,10 +136,9 @@ func (d *Depth) Invalidate(withReason error) error { // IsValid returns if the underlying book is valid. func (d *Depth) IsValid() bool { - d.m.Lock() - valid := d.validationError == nil - d.m.Unlock() - return valid + d.m.RLock() + defer d.m.RUnlock() + return d.validationError == nil } // UpdateBidAskByPrice updates the bid and ask spread by supplied updates, this @@ -285,7 +284,7 @@ func (d *Depth) AssignOptions(b *Base) { lastUpdateID: b.LastUpdateID, priceDuplication: b.PriceDuplication, isFundingRate: b.IsFundingRate, - VerifyOrderbook: b.VerifyOrderbook, + verifyOrderbook: b.VerifyOrderbook, restSnapshot: b.RestSnapshot, idAligned: b.IDAlignment, maxDepth: b.MaxDepth, @@ -296,15 +295,15 @@ func (d *Depth) AssignOptions(b *Base) { // GetName returns name of exchange func (d *Depth) GetName() string { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() return d.exchange } // IsRESTSnapshot returns if the depth was updated via REST func (d *Depth) IsRESTSnapshot() (bool, error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return false, d.validationError } @@ -313,8 +312,8 @@ func (d *Depth) IsRESTSnapshot() (bool, error) { // LastUpdateID returns the last Update ID func (d *Depth) LastUpdateID() (int64, error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return 0, d.validationError } @@ -323,15 +322,22 @@ func (d *Depth) LastUpdateID() (int64, error) { // IsFundingRate returns if the depth is a funding rate func (d *Depth) IsFundingRate() bool { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() return d.isFundingRate } +// VerifyOrderbook returns if the verify orderbook option is set +func (d *Depth) VerifyOrderbook() bool { + d.m.RLock() + defer d.m.RUnlock() + return d.verifyOrderbook +} + // GetAskLength returns length of asks func (d *Depth) GetAskLength() (int, error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return 0, d.validationError } @@ -340,8 +346,8 @@ func (d *Depth) GetAskLength() (int, error) { // GetBidLength returns length of bids func (d *Depth) GetBidLength() (int, error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return 0, d.validationError } @@ -351,8 +357,8 @@ func (d *Depth) GetBidLength() (int, error) { // TotalBidAmounts returns the total amount of bids and the total orderbook // bids value func (d *Depth) TotalBidAmounts() (liquidity, value float64, err error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return 0, 0, d.validationError } @@ -363,8 +369,8 @@ func (d *Depth) TotalBidAmounts() (liquidity, value float64, err error) { // TotalAskAmounts returns the total amount of asks and the total orderbook // asks value func (d *Depth) TotalAskAmounts() (liquidity, value float64, err error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return 0, 0, d.validationError } @@ -666,8 +672,8 @@ func (d *Depth) LiftTheAsksFromBest(amount float64, purchase bool) (*Movement, e // GetMidPrice returns the mid price between the ask and bid spread func (d *Depth) GetMidPrice() (float64, error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return 0, d.validationError } @@ -689,8 +695,8 @@ func (d *Depth) getMidPriceNoLock() (float64, error) { // GetBestBid returns the best bid price func (d *Depth) GetBestBid() (float64, error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return 0, d.validationError } @@ -699,8 +705,8 @@ func (d *Depth) GetBestBid() (float64, error) { // GetBestAsk returns the best ask price func (d *Depth) GetBestAsk() (float64, error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return 0, d.validationError } @@ -709,8 +715,8 @@ func (d *Depth) GetBestAsk() (float64, error) { // GetSpreadAmount returns the spread as a quotation amount func (d *Depth) GetSpreadAmount() (float64, error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return 0, d.validationError } @@ -727,8 +733,8 @@ func (d *Depth) GetSpreadAmount() (float64, error) { // GetSpreadPercentage returns the spread as a percentage func (d *Depth) GetSpreadPercentage() (float64, error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return 0, d.validationError } @@ -745,8 +751,8 @@ func (d *Depth) GetSpreadPercentage() (float64, error) { // GetImbalance returns top orderbook imbalance func (d *Depth) GetImbalance() (float64, error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return 0, d.validationError } @@ -769,8 +775,8 @@ func (d *Depth) GetTranches(count int) (ask, bid []Tranche, err error) { if count < 0 { return nil, nil, errInvalidBookDepth } - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.validationError != nil { return nil, nil, d.validationError } @@ -779,8 +785,8 @@ func (d *Depth) GetTranches(count int) (ask, bid []Tranche, err error) { // GetPair returns the pair associated with the depth func (d *Depth) GetPair() (currency.Pair, error) { - d.m.Lock() - defer d.m.Unlock() + d.m.RLock() + defer d.m.RUnlock() if d.pair.IsEmpty() { return currency.Pair{}, currency.ErrCurrencyPairEmpty } diff --git a/exchanges/orderbook/depth_test.go b/exchanges/orderbook/depth_test.go index f78f1973056..da96ef1b558 100644 --- a/exchanges/orderbook/depth_test.go +++ b/exchanges/orderbook/depth_test.go @@ -77,7 +77,7 @@ func TestRetrieve(t *testing.T) { lastUpdateID: 1337, priceDuplication: true, isFundingRate: true, - VerifyOrderbook: true, + verifyOrderbook: true, restSnapshot: true, idAligned: true, maxDepth: 10, @@ -441,8 +441,8 @@ func TestAssignOptions(t *testing.T) { assert.Equal(t, tn, d.lastUpdated, "lastUpdated should be correct") assert.EqualValues(t, 1337, d.lastUpdateID, "lastUpdatedID should be correct") assert.True(t, d.priceDuplication, "priceDuplication should be correct") - assert.True(t, d.isFundingRate, "isFundingRate should be correct") - assert.True(t, d.VerifyOrderbook, "VerifyOrderbook should be correct") + assert.True(t, d.IsFundingRate(), "IsFundingRate should be correct") + assert.True(t, d.VerifyOrderbook(), "VerifyOrderbook should be correct") assert.True(t, d.restSnapshot, "restSnapshot should be correct") assert.True(t, d.idAligned, "idAligned should be correct") } diff --git a/exchanges/orderbook/orderbook_types.go b/exchanges/orderbook/orderbook_types.go index c40ddf6d96d..dc1929544a3 100644 --- a/exchanges/orderbook/orderbook_types.go +++ b/exchanges/orderbook/orderbook_types.go @@ -139,7 +139,7 @@ type options struct { lastUpdateID int64 priceDuplication bool isFundingRate bool - VerifyOrderbook bool + verifyOrderbook bool restSnapshot bool idAligned bool checksumStringRequired bool diff --git a/exchanges/stream/buffer/buffer.go b/exchanges/stream/buffer/buffer.go index d2ce8a9c1ac..27d707b5b5a 100644 --- a/exchanges/stream/buffer/buffer.go +++ b/exchanges/stream/buffer/buffer.go @@ -171,7 +171,7 @@ func (w *Orderbook) Update(u *orderbook.Update) error { } var ret *orderbook.Base - if book.ob.VerifyOrderbook { + if book.ob.VerifyOrderbook() { // This is used here so as to not retrieve book if verification is off. // On every update, this will retrieve and verify orderbook depth. ret, err = book.ob.Retrieve() @@ -333,17 +333,12 @@ func (w *Orderbook) LoadSnapshot(book *orderbook.Base) error { holder.updateID = book.LastUpdateID - err = holder.ob.LoadSnapshot(book.Bids, - book.Asks, - book.LastUpdateID, - book.LastUpdated, - book.UpdatePushedAt, - false) + err = holder.ob.LoadSnapshot(book.Bids, book.Asks, book.LastUpdateID, book.LastUpdated, book.UpdatePushedAt, false) if err != nil { return err } - if holder.ob.VerifyOrderbook { + if holder.ob.VerifyOrderbook() { // This is used here so as to not retrieve book if verification is off. // Checks to see if orderbook snapshot that was deployed has not been // altered in any way diff --git a/exchanges/subscription/subscription.go b/exchanges/subscription/subscription.go index 31e0493da81..a2a47055542 100644 --- a/exchanges/subscription/subscription.go +++ b/exchanges/subscription/subscription.go @@ -31,6 +31,7 @@ const ( AllTradesChannel = "allTrades" MyTradesChannel = "myTrades" MyOrdersChannel = "myOrders" + MyWalletChannel = "myWallet" ) // Public errors