diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index dc65232416b..2ccbb3e2b87 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -499,7 +499,7 @@ func (w *Websocket) FlushChannels() error { } if !w.IsConnected() { - return fmt.Errorf("%s websocket: service not connected", w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrNotConnected) } if w.features.Subscribe { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 3e0d032bbfa..16798d0534f 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "net/http" "sort" @@ -1114,25 +1115,22 @@ func TestFlushChannels(t *testing.T) { dodgyWs := Websocket{} err := dodgyWs.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, io.EOF, "FlushChannels should error correctly") dodgyWs.setEnabled(true) err = dodgyWs.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, io.EOF, "FlushChannels should error correctly") web := Websocket{ - enabled: true, - connected: true, - connector: connect, - ShutdownC: make(chan struct{}), - Subscriber: newgen.SUBME, - Unsubscriber: newgen.UNSUBME, - Wg: new(sync.WaitGroup), - features: &protocol.Features{ + enabled: true, + connected: true, + connector: connect, + subscriptions: subscription.NewStore(), + ShutdownC: make(chan struct{}), + Subscriber: newgen.SUBME, + Unsubscriber: newgen.UNSUBME, + Wg: new(sync.WaitGroup), + features: &protocol.Features{ // No features }, trafficTimeout: time.Second * 30, // Added for when we utilise connect() @@ -1155,22 +1153,16 @@ func TestFlushChannels(t *testing.T) { return subscription.List{{Channel: "test"}}, nil } err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Flush Channels must not error") web.features.FullPayloadSubscribe = true web.GenerateSubs = problemFunc err = web.FlushChannels() // error on full subscribeToChannels - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, io.EOF, "Flush Channels should error correctly") web.GenerateSubs = noSub err = web.FlushChannels() // No subs to sub - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "Flush Channels should not error") web.GenerateSubs = newgen.generateSubs subs, err := web.GenerateSubs()