Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

common/gateio/stream: add thread-safe counter and overide default GenerateMessageID with connection specific implementation #1615

Merged
merged 8 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"unicode"

Expand Down Expand Up @@ -672,3 +673,19 @@ func SortStrings[S ~[]E, E fmt.Stringer](x S) S {
})
return n
}

// Counter is a thread-safe counter.
type Counter struct {
n int64 // privatised so you can't use counter as a value type
}

// IncrementAndGet returns the next count after incrementing.
func (c *Counter) IncrementAndGet() int64 {
newID := atomic.AddInt64(&c.n, 1)
// Handle overflow by resetting the counter to 1 if it becomes negative
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had thoughts of suggesting uint64 instead, but there are many unknowns and I doubt anyone is going to subscribe to WIF more 9,223,372,036,854,775,807 times on GateIO in a single session

if newID < 0 {
atomic.StoreInt64(&c.n, 1)
return 1
}
return newID
}
15 changes: 15 additions & 0 deletions common/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -862,3 +862,18 @@ func (a A) String() string {
func TestSortStrings(t *testing.T) {
assert.Equal(t, []A{1, 2, 5, 6}, SortStrings([]A{6, 2, 5, 1}))
}

func TestCounter(t *testing.T) {
t.Parallel()
c := Counter{n: -5}
require.Equal(t, int64(1), c.IncrementAndGet())
require.Equal(t, int64(2), c.IncrementAndGet())
}

// 683185328 1.787 ns/op 0 B/op 0 allocs/op
func BenchmarkCounter(b *testing.B) {
c := Counter{}
for i := 0; i < b.N; i++ {
c.IncrementAndGet()
}
}
1 change: 1 addition & 0 deletions exchanges/gateio/gateio.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ var (
// Gateio is the overarching type across this package
type Gateio struct {
exchange.Base
Counter common.Counter
}

// ***************************************** SubAccounts ********************************
Expand Down
5 changes: 5 additions & 0 deletions exchanges/gateio/gateio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3606,3 +3606,8 @@ func TestGetUnifiedAccount(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, payload)
}

func TestGenerateWebsocketMessageID(t *testing.T) {
t.Parallel()
require.NotEmpty(t, g.GenerateWebsocketMessageID(false))
}
6 changes: 6 additions & 0 deletions exchanges/gateio/gateio_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -868,3 +868,9 @@ func (g *Gateio) listOfAssetsCurrencyPairEnabledFor(cp currency.Pair) map[asset.
}
return assetPairEnabled
}

// GenerateWebsocketMessageID generates a message ID for the individual
// connection.
func (g *Gateio) GenerateWebsocketMessageID(bool) int64 {
return g.Counter.IncrementAndGet()
}
9 changes: 5 additions & 4 deletions exchanges/gateio/gateio_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,11 @@ func (g *Gateio) Setup(exch *config.Exchange) error {
return err
}
return g.Websocket.SetupNewConnection(stream.ConnectionSetup{
URL: gateioWebsocketEndpoint,
RateLimit: gateioWebsocketRateLimit,
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
ResponseMaxLimit: exch.WebsocketResponseMaxLimit,
URL: gateioWebsocketEndpoint,
RateLimit: gateioWebsocketRateLimit,
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
ResponseMaxLimit: exch.WebsocketResponseMaxLimit,
BespokeGenerateMessageID: g.GenerateWebsocketMessageID,
})
}

Expand Down
8 changes: 8 additions & 0 deletions exchanges/stream/stream_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ type Connection interface {
ReadMessage() Response
SendJSONMessage(any) error
SetupPingHandler(PingHandler)
// GenerateMessageID generates a message ID for the individual connection.
// If a bespoke function is set (by using SetupNewConnection) it will use
// that, otherwise it will use the defaultGenerateMessageID function defined
// in websocket_connection.go.
GenerateMessageID(highPrecision bool) int64
SendMessageReturnResponse(ctx context.Context, signature any, request any) ([]byte, error)
SendMessageReturnResponses(ctx context.Context, signature any, request any, expected int) ([][]byte, error)
Expand All @@ -41,6 +45,10 @@ type ConnectionSetup struct {
URL string
Authenticated bool
ConnectionLevelReporter Reporter
// BespokeGenerateMessageID is a function that returns a unique message ID.
// This is useful for when an exchange connection requires a unique or
// structured message ID for each message sent.
BespokeGenerateMessageID func(highPrecision bool) int64
}

// PingHandler container for ping handler settings
Expand Down
33 changes: 20 additions & 13 deletions exchanges/stream/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,13 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error {
if w == nil {
return fmt.Errorf("%w: %w", errConnSetup, errWebsocketIsNil)
}
if c == (ConnectionSetup{}) {

if c.ResponseCheckTimeout == 0 &&
c.ResponseMaxLimit == 0 &&
c.RateLimit == 0 &&
c.URL == "" &&
c.ConnectionLevelReporter == nil &&
c.BespokeGenerateMessageID == nil {
return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty)
}

Expand Down Expand Up @@ -234,18 +240,19 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error {
}

newConn := &WebsocketConnection{
ExchangeName: w.exchangeName,
URL: connectionURL,
ProxyURL: w.GetProxyAddress(),
Verbose: w.verbose,
ResponseMaxLimit: c.ResponseMaxLimit,
Traffic: w.TrafficAlert,
readMessageErrors: w.ReadMessageErrors,
ShutdownC: w.ShutdownC,
Wg: &w.Wg,
Match: w.Match,
RateLimit: c.RateLimit,
Reporter: c.ConnectionLevelReporter,
ExchangeName: w.exchangeName,
URL: connectionURL,
ProxyURL: w.GetProxyAddress(),
Verbose: w.verbose,
ResponseMaxLimit: c.ResponseMaxLimit,
Traffic: w.TrafficAlert,
readMessageErrors: w.ReadMessageErrors,
ShutdownC: w.ShutdownC,
Wg: &w.Wg,
Match: w.Match,
RateLimit: c.RateLimit,
Reporter: c.ConnectionLevelReporter,
bespokeGenerateMessageID: c.BespokeGenerateMessageID,
}

if c.Authenticated {
Expand Down
12 changes: 11 additions & 1 deletion exchanges/stream/websocket_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,18 @@ func (w *WebsocketConnection) parseBinaryResponse(resp []byte) ([]byte, error) {
return standardMessage, reader.Close()
}

// GenerateMessageID Creates a random message ID
// GenerateMessageID generates a message ID for the individual connection.
// If a bespoke function is set (by using SetupNewConnection) it will use that,
// otherwise it will use the defaultGenerateMessageID function.
func (w *WebsocketConnection) GenerateMessageID(highPrec bool) int64 {
if w.bespokeGenerateMessageID != nil {
return w.bespokeGenerateMessageID(highPrec)
}
return w.defaultGenerateMessageID(highPrec)
}

// defaultGenerateMessageID generates the default message ID
func (w *WebsocketConnection) defaultGenerateMessageID(highPrec bool) int64 {
var minValue int64 = 1e8
var maxValue int64 = 2e8
if highPrec {
Expand Down
3 changes: 3 additions & 0 deletions exchanges/stream/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,9 @@ func TestGenerateMessageID(t *testing.T) {
assert.NotContains(t, ids, id, "GenerateMessageID must not generate the same ID twice")
ids[i] = id
}

wc.bespokeGenerateMessageID = func(bool) int64 { return 42 }
assert.EqualValues(t, 42, wc.GenerateMessageID(true), "GenerateMessageID must use bespokeGenerateMessageID")
}

// BenchmarkGenerateMessageID-8 2850018 408 ns/op 56 B/op 4 allocs/op
Expand Down
5 changes: 5 additions & 0 deletions exchanges/stream/websocket_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,5 +145,10 @@ type WebsocketConnection struct {
Traffic chan struct{}
readMessageErrors chan error

// bespokeGenerateMessageID is a function that returns a unique message ID
// defined externally. This is used for exchanges that require a unique
// message ID for each message sent.
bespokeGenerateMessageID func(highPrecision bool) int64

Reporter Reporter
}
Loading