diff --git a/client/wsclient.go b/client/wsclient.go index 44c4c92a..c2e9237c 100644 --- a/client/wsclient.go +++ b/client/wsclient.go @@ -141,7 +141,8 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (err error, retryAfter sh c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err) return err, duration } - if redirect.Scheme == "http" || redirect.Scheme == "" { + // rewrite the scheme for the sake of tolerance + if redirect.Scheme == "http" { redirect.Scheme = "ws" } else if redirect.Scheme == "https" { redirect.Scheme = "wss" diff --git a/client/wsclient_test.go b/client/wsclient_test.go index 70aea811..4e87a3b0 100644 --- a/client/wsclient_test.go +++ b/client/wsclient_test.go @@ -189,41 +189,56 @@ func redirectServer(to string) *httptest.Server { func TestRedirectWS(t *testing.T) { redirectee := internal.StartMockServer(t) - redirector := redirectServer("ws://" + redirectee.Endpoint) - defer redirector.Close() - - var conn atomic.Value - redirectee.OnWSConnect = func(c *websocket.Conn) { - conn.Store(c) - } - - // Start an OpAMP/WebSocket client. - var connected int64 - var connectErr atomic.Value - settings := types.StartSettings{ - Callbacks: types.CallbacksStruct{ - OnConnectFunc: func(ctx context.Context) { - atomic.StoreInt64(&connected, 1) - }, - OnConnectFailedFunc: func(ctx context.Context, err error) { - if err != websocket.ErrBadHandshake { - connectErr.Store(err) - } - }, + tests := []struct { + Name string + Redirector *httptest.Server + }{ + { + Name: "redirect ws scheme", + Redirector: redirectServer("ws://" + redirectee.Endpoint), + }, + { + Name: "redirect http scheme", + Redirector: redirectServer("http://" + redirectee.Endpoint), }, } - reURL, err := url.Parse(redirector.URL) - assert.NoError(t, err) - reURL.Scheme = "ws" - settings.OpAMPServerURL = reURL.String() - client := NewWebSocket(nil) - startClient(t, settings, client) - // Wait for connection to be established. - eventually(t, func() bool { return conn.Load() != nil }) - assert.True(t, connectErr.Load() == nil) + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + var conn atomic.Value + redirectee.OnWSConnect = func(c *websocket.Conn) { + conn.Store(c) + } - // Stop the client. - err = client.Stop(context.Background()) - assert.NoError(t, err) + // Start an OpAMP/WebSocket client. + var connected int64 + var connectErr atomic.Value + settings := types.StartSettings{ + Callbacks: types.CallbacksStruct{ + OnConnectFunc: func(ctx context.Context) { + atomic.StoreInt64(&connected, 1) + }, + OnConnectFailedFunc: func(ctx context.Context, err error) { + if err != websocket.ErrBadHandshake { + connectErr.Store(err) + } + }, + }, + } + reURL, err := url.Parse(test.Redirector.URL) + assert.NoError(t, err) + reURL.Scheme = "ws" + settings.OpAMPServerURL = reURL.String() + client := NewWebSocket(nil) + startClient(t, settings, client) + + // Wait for connection to be established. + eventually(t, func() bool { return conn.Load() != nil }) + assert.True(t, connectErr.Load() == nil) + + // Stop the client. + err = client.Stop(context.Background()) + assert.NoError(t, err) + }) + } }