From 9fddd2ccfb6724ae4e883fae250b965ce7f7c2f4 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 14 Oct 2023 06:13:10 -0700 Subject: [PATCH] dial: Redirect wss/ws correctly by modifying the http client Closes #333 Needs a test. --- dial.go | 15 +++++++++++++++ dial_test.go | 6 ++++++ 2 files changed, 21 insertions(+) diff --git a/dial.go b/dial.go index e72432e7..e4c4daa1 100644 --- a/dial.go +++ b/dial.go @@ -70,6 +70,21 @@ func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context if o.HTTPHeader == nil { o.HTTPHeader = http.Header{} } + newClient := *o.HTTPClient + oldCheckRedirect := o.HTTPClient.CheckRedirect + newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + switch req.URL.Scheme { + case "ws": + req.URL.Scheme = "http" + case "wss": + req.URL.Scheme = "https" + } + if oldCheckRedirect != nil { + return oldCheckRedirect(req, via) + } + return nil + } + o.HTTPClient = &newClient return ctx, cancel, &o } diff --git a/dial_test.go b/dial_test.go index e072db2d..641d8377 100644 --- a/dial_test.go +++ b/dial_test.go @@ -304,3 +304,9 @@ type roundTripperFunc func(*http.Request) (*http.Response, error) func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +func TestDialRedirect(t *testing.T) { + t.Parallel() + + // TODO: +}