Skip to content

Commit

Permalink
Websocket: Simplify Connecting/Connected state
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Jan 30, 2024
1 parent e8395a5 commit 9acf5a3
Show file tree
Hide file tree
Showing 26 changed files with 83 additions and 95 deletions.
2 changes: 1 addition & 1 deletion engine/websocketroutine_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) {
t.Fatal("unexpected data handlers registered")
}

mock := stream.New()
mock := stream.NewWebsocket()
mock.ToRoutine = make(chan interface{})
m.state = readyState
err = m.websocketDataReceiver(mock)
Expand Down
2 changes: 1 addition & 1 deletion exchanges/binance/binance_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func (b *Binance) SetDefaults() {
log.Errorln(log.ExchangeSys, err)
}

b.Websocket = stream.New()
b.Websocket = stream.NewWebsocket()
b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
}
Expand Down
2 changes: 1 addition & 1 deletion exchanges/binanceus/binanceus_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func (bi *Binanceus) SetDefaults() {
"%s setting default endpoints error %v",
bi.Name, err)
}
bi.Websocket = stream.New()
bi.Websocket = stream.NewWebsocket()
bi.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
bi.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
bi.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/bitfinex/bitfinex_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (b *Bitfinex) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.Websocket = stream.New()
b.Websocket = stream.NewWebsocket()
b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/bithumb/bithumb_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func (b *Bithumb) SetDefaults() {
log.Errorln(log.ExchangeSys, err)
}

b.Websocket = stream.New()
b.Websocket = stream.NewWebsocket()
b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
}
Expand Down
2 changes: 1 addition & 1 deletion exchanges/bitmex/bitmex_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func (b *Bitmex) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.Websocket = stream.New()
b.Websocket = stream.NewWebsocket()
b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/bitstamp/bitstamp_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (b *Bitstamp) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.Websocket = stream.New()
b.Websocket = stream.NewWebsocket()
b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/btcmarkets/btcmarkets_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func (b *BTCMarkets) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.Websocket = stream.New()
b.Websocket = stream.NewWebsocket()
b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/btse/btse_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (b *BTSE) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.Websocket = stream.New()
b.Websocket = stream.NewWebsocket()
b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/bybit/bybit_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ func (by *Bybit) SetDefaults() {
log.Errorln(log.ExchangeSys, err)
}

by.Websocket = stream.New()
by.Websocket = stream.NewWebsocket()
by.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
by.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
by.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/coinbasepro/coinbasepro_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (c *CoinbasePro) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
c.Websocket = stream.New()
c.Websocket = stream.NewWebsocket()
c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/coinut/coinut_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (c *COINUT) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
c.Websocket = stream.New()
c.Websocket = stream.NewWebsocket()
c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
6 changes: 3 additions & 3 deletions exchanges/exchange_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func TestSetClientProxyAddress(t *testing.T) {
Name: "rawr",
Requester: requester}

newBase.Websocket = stream.New()
newBase.Websocket = stream.NewWebsocket()
err = newBase.SetClientProxyAddress("")
if err != nil {
t.Error(err)
Expand Down Expand Up @@ -1251,7 +1251,7 @@ func TestSetupDefaults(t *testing.T) {
}

// Test websocket support
b.Websocket = stream.New()
b.Websocket = stream.NewWebsocket()
b.Features.Supports.Websocket = true
err = b.Websocket.Setup(&stream.WebsocketSetup{
ExchangeConfig: &config.Exchange{
Expand Down Expand Up @@ -1596,7 +1596,7 @@ func TestIsWebsocketEnabled(t *testing.T) {
t.Error("exchange doesn't support websocket")
}

b.Websocket = stream.New()
b.Websocket = stream.NewWebsocket()
err := b.Websocket.Setup(&stream.WebsocketSetup{
ExchangeConfig: &config.Exchange{
Enabled: true,
Expand Down
2 changes: 1 addition & 1 deletion exchanges/gateio/gateio_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (g *Gateio) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
g.Websocket = stream.New()
g.Websocket = stream.NewWebsocket()
g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/gemini/gemini_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (g *Gemini) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
g.Websocket = stream.New()
g.Websocket = stream.NewWebsocket()
g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/hitbtc/hitbtc_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (h *HitBTC) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
h.Websocket = stream.New()
h.Websocket = stream.NewWebsocket()
h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/huobi/huobi_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func (h *HUOBI) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
h.Websocket = stream.New()
h.Websocket = stream.NewWebsocket()
h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/kraken/kraken_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (k *Kraken) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
k.Websocket = stream.New()
k.Websocket = stream.NewWebsocket()
k.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
k.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
k.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/kucoin/kucoin_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func (ku *Kucoin) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
ku.Websocket = stream.New()
ku.Websocket = stream.NewWebsocket()
ku.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
ku.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
ku.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/okcoin/okcoin_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func (o *Okcoin) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
o.Websocket = stream.New()
o.Websocket = stream.NewWebsocket()
o.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
o.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
o.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/okx/okx_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func (ok *Okx) SetDefaults() {
log.Errorln(log.ExchangeSys, err)
}

ok.Websocket = stream.New()
ok.Websocket = stream.NewWebsocket()
ok.WebsocketResponseMaxLimit = okxWebsocketResponseMaxLimit
ok.WebsocketResponseCheckTimeout = okxWebsocketResponseMaxLimit
ok.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
2 changes: 1 addition & 1 deletion exchanges/poloniex/poloniex_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func (p *Poloniex) SetDefaults() {
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
p.Websocket = stream.New()
p.Websocket = stream.NewWebsocket()
p.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
p.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
p.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
Expand Down
1 change: 0 additions & 1 deletion exchanges/sharedtestvalues/sharedtestvalues.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ func GetWebsocketStructChannelOverride() chan struct{} {
// NewTestWebsocket returns a test websocket object
func NewTestWebsocket() *stream.Websocket {
return &stream.Websocket{
Init: true,
DataHandler: make(chan interface{}, WebsocketChannelOverrideCapacity),
ToRoutine: make(chan interface{}, 1000),
TrafficAlert: make(chan struct{}),
Expand Down
37 changes: 17 additions & 20 deletions exchanges/stream/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ var (
errChannelAlreadySubscribed = errors.New("channel already subscribed")
errInvalidChannelState = errors.New("invalid Channel state")
errSameProxyAddress = errors.New("cannot set proxy address to the same address")
errNoConnectFunc = errors.New("connect func not set")
errAlreadyConnected = errors.New("already connected")
)

var globalReporter Reporter
Expand All @@ -75,8 +77,8 @@ func SetupGlobalReporter(r Reporter) {
globalReporter = r
}

// New initialises the websocket struct
func New() *Websocket {
// NewWebsocket initialises the websocket struct
func NewWebsocket() *Websocket {
return &Websocket{
DataHandler: make(chan interface{}, defaultJobBuffer),
ToRoutine: make(chan interface{}, defaultJobBuffer),
Expand Down Expand Up @@ -188,6 +190,8 @@ func (w *Websocket) Setup(s *WebsocketSetup) error {
return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions)
}
w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection
w.setState(disconnected)

return nil
}

Expand Down Expand Up @@ -253,7 +257,7 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error {
// function
func (w *Websocket) Connect() error {
if w.connector == nil {
return errors.New("websocket connect function not set, cannot continue")
return errNoConnectFunc
}
w.m.Lock()
defer w.m.Unlock()
Expand All @@ -270,15 +274,14 @@ func (w *Websocket) Connect() error {

w.dataMonitor()
w.trafficMonitor()
w.setConnectingStatus(true)
w.setState(connecting)

err := w.connector()
if err != nil {
w.setConnectingStatus(false)
w.setState(disconnected)
return fmt.Errorf("%v Error connecting %w", w.exchangeName, err)
}
w.setConnectedStatus(true)
w.setConnectingStatus(false)
w.setState(connected)

if !w.IsConnectionMonitorRunning() {
err = w.connectionMonitor()
Expand Down Expand Up @@ -306,6 +309,7 @@ func (w *Websocket) Connect() error {
}

// Disable disables the exchange websocket protocol
// Note that connectionMonitor will be responsible for shutting down the websocket after disabling
func (w *Websocket) Disable() error {
if !w.IsEnabled() {
return fmt.Errorf("%w for exchange '%s'", ErrAlreadyDisabled, w.exchangeName)
Expand Down Expand Up @@ -405,7 +409,7 @@ func (w *Websocket) connectionMonitor() error {
case err := <-w.ReadMessageErrors:
if IsDisconnectionError(err) {
log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err)
w.setConnectedStatus(false)
w.setState(disconnected)
}

w.DataHandler <- err
Expand Down Expand Up @@ -470,8 +474,7 @@ func (w *Websocket) Shutdown() error {
close(w.ShutdownC)
w.Wg.Wait()
w.ShutdownC = make(chan struct{})
w.setConnectedStatus(false)
w.setConnectingStatus(false)
w.setState(disconnected)
if w.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName)
}
Expand Down Expand Up @@ -565,7 +568,7 @@ func (w *Websocket) trafficMonitor() {
default:
}
}
w.setConnectedStatus(true)
w.setState(connected)
trafficTimer.Reset(w.trafficTimeout)
case <-trafficTimer.C: // Falls through when timer runs out
if w.verbose {
Expand Down Expand Up @@ -606,9 +609,9 @@ func (w *Websocket) trafficMonitor() {
}()
}

func (w *Websocket) setConnectedStatus(b bool) {
func (w *Websocket) setState(s state) {
w.fieldMutex.Lock()
w.state = connected
w.state = s
w.fieldMutex.Unlock()
}

Expand All @@ -619,17 +622,11 @@ func (w *Websocket) IsConnected() bool {
return w.state == connected
}

func (w *Websocket) setConnectingStatus(b bool) {
w.fieldMutex.Lock()
w.state = connecting
w.fieldMutex.Unlock()
}

// IsConnecting returns status of connecting
func (w *Websocket) IsConnecting() bool {
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.connecting
return w.state == connecting
}

func (w *Websocket) setEnabled(b bool) {
Expand Down
Loading

0 comments on commit 9acf5a3

Please sign in to comment.