Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow HTTP redirects after failed WS dials #251

Merged
merged 7 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion client/wsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,28 @@
c.common.Callbacks.OnConnectFailed(ctx, err)
}
if resp != nil {
c.common.Logger.Errorf(ctx, "Server responded with status=%v", resp.Status)
duration := sharedinternal.ExtractRetryAfterHeader(resp)
if resp.StatusCode >= 300 && resp.StatusCode < 400 {
// very liberal handling of 3xx that largely ignores HTTP semantics
redirect, err := resp.Location()
if err != nil {
c.common.Callbacks.OnConnectFailed(ctx, err)
tigrannajaryan marked this conversation as resolved.
Show resolved Hide resolved
c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err)
return err, duration
}
// rewrite the scheme for the sake of tolerance
if redirect.Scheme == "http" {
redirect.Scheme = "ws"
} else if redirect.Scheme == "https" {
redirect.Scheme = "wss"

Check warning on line 147 in client/wsclient.go

View check run for this annotation

Codecov / codecov/patch

client/wsclient.go#L147

Added line #L147 was not covered by tests
}
c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect)
// Set the URL to the redirect, so that it connects to it on the
// next cycle.
c.url = redirect
echlebek marked this conversation as resolved.
Show resolved Hide resolved
} else {
c.common.Logger.Errorf(ctx, "Server responded with status=%v", resp.Status)
}
return err, duration
}
return err, sharedinternal.OptionalDuration{Defined: false}
Expand Down
83 changes: 83 additions & 0 deletions client/wsclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package client
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -177,3 +180,83 @@ func TestVerifyWSCompress(t *testing.T) {
})
}
}

func redirectServer(to string, status int) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
http.Redirect(w, req, to, http.StatusSeeOther)
}))
}

func errServer() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(302)
}))
}

func TestRedirectWS(t *testing.T) {
tigrannajaryan marked this conversation as resolved.
Show resolved Hide resolved
redirectee := internal.StartMockServer(t)
tests := []struct {
Name string
Redirector *httptest.Server
ExpError bool
}{
{
Name: "redirect ws scheme",
Redirector: redirectServer("ws://"+redirectee.Endpoint, 302),
},
{
Name: "redirect http scheme",
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
},
{
Name: "missing location header",
Redirector: errServer(),
ExpError: true,
},
}

for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
var conn atomic.Value
redirectee.OnWSConnect = func(c *websocket.Conn) {
conn.Store(c)
}

// Start an OpAMP/WebSocket client.
var connected int64
var connectErr atomic.Value
settings := types.StartSettings{
Callbacks: types.CallbacksStruct{
OnConnectFunc: func(ctx context.Context) {
atomic.StoreInt64(&connected, 1)
},
OnConnectFailedFunc: func(ctx context.Context, err error) {
if err != websocket.ErrBadHandshake {
connectErr.Store(err)
}
},
},
}
reURL, err := url.Parse(test.Redirector.URL)
assert.NoError(t, err)
reURL.Scheme = "ws"
settings.OpAMPServerURL = reURL.String()
client := NewWebSocket(nil)
startClient(t, settings, client)

// Wait for connection to be established.
eventually(t, func() bool { return conn.Load() != nil || connectErr.Load() != nil })
if test.ExpError {
if connectErr.Load() == nil {
t.Error("expected non-nil error")
}
} else {
assert.True(t, connectErr.Load() == nil)
}

// Stop the client.
err = client.Stop(context.Background())
assert.NoError(t, err)
})
}
}
Loading