diff --git a/exchanges/kraken/kraken_test.go b/exchanges/kraken/kraken_test.go index e15934b3a98..654c0bb47ef 100644 --- a/exchanges/kraken/kraken_test.go +++ b/exchanges/kraken/kraken_test.go @@ -1020,6 +1020,35 @@ func TestWsSubscribe(t *testing.T) { } } +// TestWsResubscribe tests websocket resubscription +func TestWsResubscribe(t *testing.T) { + k := new(Kraken) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + require.NoError(t, testexch.Setup(k), "TestInstance must not error") + testexch.SetupWs(t, k) + + err := k.Subscribe(subscription.List{{Asset: asset.Spot, Channel: subscription.OrderbookChannel, Levels: 1000}}) + require.NoError(t, err, "Subscribe must not error") + subs := k.Websocket.GetSubscriptions() + require.Len(t, subs, 1, "Should add 1 Subscription") + require.Equal(t, subscription.SubscribedState, subs[0].State(), "Subscription should be subscribed state") + + require.Eventually(t, func() bool { + b, err := k.Websocket.Orderbook.GetOrderbook(xbtusdPair, asset.Spot) + if err == nil { + return !b.LastUpdated.IsZero() + } + return false + }, time.Second*4, time.Millisecond*10, "orderbook must start streaming") + + // Set the state to Unsub so we definitely know Resub worked + err = subs[0].SetState(subscription.UnsubscribingState) + require.NoError(t, err) + + err = k.Websocket.ResubscribeToChannel(subs[0]) + require.NoError(t, err, "Resubscribe must not error") + require.Equal(t, subscription.SubscribedState, subs[0].State(), "subscription must be subscribed again") +} + // TestWsOrderbookSub tests orderbook subscriptions for MaxDepth params func TestWsOrderbookSub(t *testing.T) { t.Parallel() diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index b310ca5bb05..edcdb653a09 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -957,6 +957,9 @@ func (w *Websocket) checkSubscriptions(subs subscription.List) error { } for _, s := range subs { + if s.State() == subscription.ResubscribingState { + continue + } if found := w.subscriptions.Get(s); found != nil { return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s) }