From 1fe7486119e3f892453c2ddb68dee6ca0a67637a Mon Sep 17 00:00:00 2001 From: Jean Pierre Date: Wed, 10 Aug 2022 21:03:38 +0000 Subject: [PATCH] Fix X-Forwarded-* headers --- components/openvsx-proxy/pkg/handler.go | 2 - .../openvsx-proxy/pkg/modifyresponse.go | 51 +---------- .../openvsx-proxy/pkg/openvsxproxy_test.go | 85 ------------------- components/openvsx-proxy/pkg/run.go | 79 ++++++++++++++++- 4 files changed, 82 insertions(+), 135 deletions(-) diff --git a/components/openvsx-proxy/pkg/handler.go b/components/openvsx-proxy/pkg/handler.go index 2ddcc486dabdcc..36f2ca3b645124 100644 --- a/components/openvsx-proxy/pkg/handler.go +++ b/components/openvsx-proxy/pkg/handler.go @@ -49,7 +49,6 @@ func (o *OpenVSXProxy) Handler(p *httputil.ReverseProxy) func(http.ResponseWrite key, err := o.key(r) if err != nil { log.WithFields(logFields).WithError(err).Error("cannot create cache key") - r.Host = o.upstreamURL.Host p.ServeHTTP(rw, r) o.finishLog(logFields, start, hitCacheRegular, hitCacheBackup) o.metrics.DurationRequestProcessingHistogram.Observe(time.Since(start).Seconds()) @@ -105,7 +104,6 @@ func (o *OpenVSXProxy) Handler(p *httputil.ReverseProxy) func(http.ResponseWrite log.WithFields(logFields).WithFields(o.DurationLogFields(duration)).Info("processing request finished") o.metrics.DurationRequestProcessingHistogram.Observe(duration.Seconds()) - r.Host = o.upstreamURL.Host p.ServeHTTP(rw, r) o.finishLog(logFields, start, hitCacheRegular, hitCacheBackup) } diff --git a/components/openvsx-proxy/pkg/modifyresponse.go b/components/openvsx-proxy/pkg/modifyresponse.go index d663d7bb279581..35908b4cca9b3c 100644 --- a/components/openvsx-proxy/pkg/modifyresponse.go +++ b/components/openvsx-proxy/pkg/modifyresponse.go @@ -6,12 +6,10 @@ package pkg import ( "bytes" - "compress/gzip" "fmt" "io/ioutil" "net/http" "strconv" - "strings" "time" "unicode/utf8" @@ -55,7 +53,6 @@ func (o *OpenVSXProxy) ModifyResponse(r *http.Response) error { return err } r.Body.Close() - r.Body = ioutil.NopCloser(bytes.NewBuffer(rawBody)) if r.StatusCode >= 500 || r.StatusCode == http.StatusTooManyRequests || r.StatusCode == http.StatusRequestTimeout { // use cache if exists @@ -94,49 +91,9 @@ func (o *OpenVSXProxy) ModifyResponse(r *http.Response) error { } // no error (status code < 500) - body := rawBody - contentType := r.Header.Get("Content-Type") - if strings.HasPrefix(contentType, "application/json") { - isCompressedResponse := strings.EqualFold(r.Header.Get("Content-Encoding"), "gzip") - if isCompressedResponse { - gzipReader, err := gzip.NewReader(ioutil.NopCloser(bytes.NewBuffer(rawBody))) - if err != nil { - log.WithFields(logFields).WithError(err) - return nil - } - - body, err = ioutil.ReadAll(gzipReader) - if err != nil { - log.WithFields(logFields).WithError(err).Error("error reading compressed response body") - return nil - } - gzipReader.Close() - } - - if log.Log.Level >= logrus.DebugLevel { - log.WithFields(logFields).Debugf("replacing %d occurence(s) of '%s' in response body ...", strings.Count(string(body), o.Config.URLUpstream), o.Config.URLUpstream) - } - bodyStr := strings.ReplaceAll(string(body), o.Config.URLUpstream, o.Config.URLLocal) - body = []byte(bodyStr) - - if isCompressedResponse { - var b bytes.Buffer - gzipWriter := gzip.NewWriter(&b) - _, err = gzipWriter.Write(body) - if err != nil { - log.WithFields(logFields).WithError(err).Error("error writing compressed response body") - return nil - } - gzipWriter.Close() - body = b.Bytes() - } - } else { - log.WithFields(logFields).Debugf("response is not JSON but '%s', skipping replacing '%s' in response body", contentType, o.Config.URLUpstream) - } - cacheObj := &CacheObject{ Header: r.Header, - Body: body, + Body: rawBody, StatusCode: r.StatusCode, } err = o.StoreCache(key, cacheObj) @@ -146,8 +103,8 @@ func (o *OpenVSXProxy) ModifyResponse(r *http.Response) error { log.WithFields(logFields).Info("successfully stored response to cache") } - r.Body = ioutil.NopCloser(bytes.NewBuffer(body)) - r.ContentLength = int64(len(body)) - r.Header.Set("Content-Length", strconv.Itoa(len(body))) + r.Body = ioutil.NopCloser(bytes.NewBuffer(rawBody)) + r.ContentLength = int64(len(rawBody)) + r.Header.Set("Content-Length", strconv.Itoa(len(rawBody))) return nil } diff --git a/components/openvsx-proxy/pkg/openvsxproxy_test.go b/components/openvsx-proxy/pkg/openvsxproxy_test.go index 3c16a34eac3d13..03e0e10ec919c9 100644 --- a/components/openvsx-proxy/pkg/openvsxproxy_test.go +++ b/components/openvsx-proxy/pkg/openvsxproxy_test.go @@ -6,7 +6,6 @@ package pkg import ( "bytes" - "compress/gzip" "fmt" "io" "net/http" @@ -30,90 +29,6 @@ func createFrontend(backendURL string) (*httptest.Server, *OpenVSXProxy) { return frontend, openVSXProxy } -func TestReplaceHostInJSONResponse(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - bodyBytes, _ := io.ReadAll(r.Body) - rw.Header().Set("Content-Type", "application/json") - rw.Write([]byte(fmt.Sprintf("Hello %s!", string(bodyBytes)))) - })) - defer backend.Close() - - frontend, _ := createFrontend(backend.URL) - defer frontend.Close() - - frontendClient := frontend.Client() - - requestBody := backend.URL - req, _ := http.NewRequest("POST", frontend.URL, bytes.NewBuffer([]byte(requestBody))) - req.Close = true - res, err := frontendClient.Do(req) - if err != nil { - t.Fatal(err) - } - expectedResponse := fmt.Sprintf("Hello %s!", frontend.URL) - if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expectedResponse { - t.Errorf("got body '%s'; expected '%s'", string(bodyBytes), expectedResponse) - } -} - -func TestReplaceHostInCompressedJSONResponse(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - bodyBytes, _ := io.ReadAll(r.Body) - rw.Header().Set("Content-Type", "application/json") - rw.Header().Set("Content-Encoding", "gzip") - - var b bytes.Buffer - w := gzip.NewWriter(&b) - w.Write([]byte(fmt.Sprintf("Hello %s!", string(bodyBytes)))) - w.Close() - rw.Write(b.Bytes()) - })) - defer backend.Close() - - frontend, _ := createFrontend(backend.URL) - defer frontend.Close() - - frontendClient := frontend.Client() - - requestBody := backend.URL - req, _ := http.NewRequest("POST", frontend.URL, bytes.NewBuffer([]byte(requestBody))) - req.Close = true - res, err := frontendClient.Do(req) - if err != nil { - t.Fatal(err) - } - expectedResponse := fmt.Sprintf("Hello %s!", frontend.URL) - if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expectedResponse { - t.Errorf("got body '%s'; expected '%s'", string(bodyBytes), expectedResponse) - } -} - -func TestNotReplaceHostInNonJSONResponse(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - bodyBytes, _ := io.ReadAll(r.Body) - rw.Header().Set("Content-Type", "application/octet-stream") - rw.Write([]byte(fmt.Sprintf("Hello %s!", string(bodyBytes)))) - })) - defer backend.Close() - - frontend, _ := createFrontend(backend.URL) - defer frontend.Close() - - frontendClient := frontend.Client() - - requestBody := backend.URL - req, _ := http.NewRequest("POST", frontend.URL, bytes.NewBuffer([]byte(requestBody))) - req.Close = true - res, err := frontendClient.Do(req) - if err != nil { - t.Fatal(err) - } - expectedResponse := fmt.Sprintf("Hello %s!", backend.URL) - if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expectedResponse { - t.Errorf("got body '%s'; expected '%s'", string(bodyBytes), expectedResponse) - } -} - func TestAddResponseToCache(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { bodyBytes, _ := io.ReadAll(r.Body) diff --git a/components/openvsx-proxy/pkg/run.go b/components/openvsx-proxy/pkg/run.go index 02dbacc0deebc7..09b1ae109f0991 100644 --- a/components/openvsx-proxy/pkg/run.go +++ b/components/openvsx-proxy/pkg/run.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httputil" "net/url" + "strings" "time" "github.com/eko/gocache/cache" @@ -58,7 +59,7 @@ func (o *OpenVSXProxy) Start() (shutdown func(context.Context) error, err error) return nil, err } } - proxy := httputil.NewSingleHostReverseProxy(o.upstreamURL) + proxy := newSingleHostReverseProxy(o.upstreamURL) proxy.ErrorHandler = o.ErrorHandler proxy.ModifyResponse = o.ModifyResponse proxy.Transport = &DurationTrackingTransport{o: o} @@ -112,3 +113,79 @@ func (t *DurationTrackingTransport) RoundTrip(r *http.Request) (*http.Response, }(start) return http.DefaultTransport.RoundTrip(r) } + +// From go/src/net/http/httputil/reverseproxy.go + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +func joinURLPath(a, b *url.URL) (path, rawpath string) { + if a.RawPath == "" && b.RawPath == "" { + return singleJoiningSlash(a.Path, b.Path), "" + } + // Same as singleJoiningSlash, but uses EscapedPath to determine + // whether a slash should be added + apath := a.EscapedPath() + bpath := b.EscapedPath() + + aslash := strings.HasSuffix(apath, "/") + bslash := strings.HasPrefix(bpath, "/") + + switch { + case aslash && bslash: + return a.Path + b.Path[1:], apath + bpath[1:] + case !aslash && !bslash: + return a.Path + "/" + b.Path, apath + "/" + bpath + } + return a.Path + b.Path, apath + bpath +} + +func newSingleHostReverseProxy(target *url.URL) *httputil.ReverseProxy { + targetQuery := target.RawQuery + director := func(req *http.Request) { + originalHost := req.Host + + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } + req.Host = target.Host + + if _, ok := req.Header["User-Agent"]; !ok { + // explicitly disable User-Agent so it's not set to default value + req.Header.Set("User-Agent", "") + } + + // From https://github.com/golang/go/pull/36678 + prior, ok := req.Header["X-Forwarded-Host"] + omit := ok && prior == nil // nil means don't populate the header + if !omit { + req.Header.Set("X-Forwarded-Host", originalHost) + } + + prior, ok = req.Header["X-Forwarded-Proto"] + omit = ok && prior == nil // nil means don't populate the header + if !omit { + if req.TLS == nil { + req.Header.Set("X-Forwarded-Proto", "http") + } else { + req.Header.Set("X-Forwarded-Proto", "https") + } + } + // ReverseProxy will add X-Forwarded-For internally + } + return &httputil.ReverseProxy{Director: director} +}