Skip to content

Commit

Permalink
follow redirects
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanp413 committed Apr 5, 2022
1 parent 3b3f86d commit fcd932c
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 25 deletions.
61 changes: 37 additions & 24 deletions components/openvsx-proxy/pkg/modifyresponse.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"compress/gzip"
"fmt"
"io"
"io/ioutil"
"net/http"
"strconv"
Expand Down Expand Up @@ -57,6 +58,40 @@ func (o *OpenVSXProxy) ModifyResponse(r *http.Response) error {
r.Body.Close()
r.Body = ioutil.NopCloser(bytes.NewBuffer(rawBody))

if r.StatusCode >= 300 && r.StatusCode < 400 && o.Config.FollowRedirects > 0 && r.Header.Get("Location") != "" {
var newBody io.ReadCloser
if r.Request.Body != nil {
// TODO: r.Request.Body seems to be always empty so for now it only works for GET requests
rawReqBody, err := ioutil.ReadAll(r.Request.Body)
if err != nil {
log.WithFields(logFields).WithError(err).Error("error reading request raw body")
return nil
}
newBody = ioutil.NopCloser(bytes.NewBuffer(rawReqBody))
}

redirectReq, err := http.NewRequestWithContext(r.Request.Context(), r.Request.Method, r.Header.Get("Location"), newBody)
if err != nil {
log.WithFields(logFields).WithError(err)
return nil
}

log.WithFields(logFields).Infof("following redirect request: %s", r.Header.Get("Location"))
r, err = o.client.Do(redirectReq)
if err != nil {
log.WithFields(logFields).WithError(err).Error("error following redirect request")
return nil
}

rawBody, err = ioutil.ReadAll(r.Body)
if err != nil {
log.WithFields(logFields).WithError(err).Error("error reading redirect response raw body")
return nil
}
r.Body.Close()
r.Body = ioutil.NopCloser(bytes.NewBuffer(rawBody))
}

if r.StatusCode >= 500 || r.StatusCode == http.StatusTooManyRequests || r.StatusCode == http.StatusRequestTimeout {
// use cache if exists
bodyLogField := "(binary)"
Expand Down Expand Up @@ -94,42 +129,20 @@ func (o *OpenVSXProxy) ModifyResponse(r *http.Response) error {
}

// no error (status code < 500)
if r.StatusCode >= 300 && r.StatusCode < 400 && o.Config.FollowRedirects > 0 && r.Header.Get("Location") != "" {
newBody, err := r.Request.GetBody()
if err != nil {
log.WithFields(logFields).WithError(err).Error("error getting body for redirect")
return err
}

redirectReq, err := http.NewRequestWithContext(r.Request.Context(), r.Request.Method, r.Header.Get("Location"), newBody)
r, err = o.client.Do(redirectReq)
if err != nil {
log.WithFields(logFields).WithError(err).Error("error doing redirect request")
return err
}

rawBody, err = ioutil.ReadAll(r.Body)
if err != nil {
log.WithFields(logFields).WithError(err).Error("error reading redirect response raw body")
return err
}
r.Body.Close()
r.Body = ioutil.NopCloser(bytes.NewBuffer(rawBody))
}

body := rawBody
contentType := r.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
if strings.EqualFold(r.Header.Get("Content-Encoding"), "gzip") {
gzipReader, err := gzip.NewReader(ioutil.NopCloser(bytes.NewBuffer(rawBody)))
if err != nil {
log.WithFields(logFields).WithError(err)
return nil
}

body, err = ioutil.ReadAll(gzipReader)
if err != nil {
log.WithFields(logFields).WithError(err).Error("error reading response body")
return err
return nil
}
gzipReader.Close()

Expand Down
66 changes: 65 additions & 1 deletion components/openvsx-proxy/pkg/openvsxproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package pkg

import (
"bytes"
"compress/gzip"
"fmt"
"io"
"net/http"
Expand All @@ -16,7 +17,8 @@ import (

func createFrontend(backendURL string) (*httptest.Server, *OpenVSXProxy) {
cfg := &Config{
URLUpstream: backendURL,
URLUpstream: backendURL,
FollowRedirects: 3,
}
openVSXProxy := &OpenVSXProxy{Config: cfg}
openVSXProxy.Setup()
Expand Down Expand Up @@ -55,6 +57,38 @@ func TestReplaceHostInJSONResponse(t *testing.T) {
}
}

func TestReplaceHostInCompressedJSONResponse(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
bodyBytes, _ := io.ReadAll(r.Body)
rw.Header().Set("Content-Type", "application/json")
rw.Header().Set("Content-Encoding", "gzip")

var b bytes.Buffer
w := gzip.NewWriter(&b)
w.Write([]byte(fmt.Sprintf("Hello %s!", string(bodyBytes))))
w.Close()
rw.Write(b.Bytes())
}))
defer backend.Close()

frontend, _ := createFrontend(backend.URL)
defer frontend.Close()

frontendClient := frontend.Client()

requestBody := backend.URL
req, _ := http.NewRequest("POST", frontend.URL, bytes.NewBuffer([]byte(requestBody)))
req.Close = true
res, err := frontendClient.Do(req)
if err != nil {
t.Fatal(err)
}
expectedResponse := fmt.Sprintf("Hello %s!", frontend.URL)
if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expectedResponse {
t.Errorf("got body '%s'; expected '%s'", string(bodyBytes), expectedResponse)
}
}

func TestNotReplaceHostInNonJSONResponse(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
bodyBytes, _ := io.ReadAll(r.Body)
Expand All @@ -81,6 +115,36 @@ func TestNotReplaceHostInNonJSONResponse(t *testing.T) {
}
}

func TestReplaceHostInRedirectJSONResponse(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json")
rw.Write([]byte("Hello world!"))
}))
defer backend.Close()

redirectBackend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
http.Redirect(rw, r, backend.URL, http.StatusFound)
}))
defer redirectBackend.Close()

frontend, _ := createFrontend(redirectBackend.URL)
defer frontend.Close()

frontendClient := frontend.Client()

requestBody := redirectBackend.URL
req, _ := http.NewRequest("GET", frontend.URL, bytes.NewBuffer([]byte(requestBody)))
req.Close = true
res, err := frontendClient.Do(req)
if err != nil {
t.Fatal(err)
}
expectedResponse := "Hello world!"
if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expectedResponse {
t.Errorf("got body '%s'; expected '%s'", string(bodyBytes), expectedResponse)
}
}

func TestAddResponseToCache(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
bodyBytes, _ := io.ReadAll(r.Body)
Expand Down

0 comments on commit fcd932c

Please sign in to comment.