diff --git a/middleware/proxy/README.md b/middleware/proxy/README.md index 6cc03e8b9d..e8663878fc 100644 --- a/middleware/proxy/README.md +++ b/middleware/proxy/README.md @@ -18,6 +18,12 @@ func Balancer(config Config) fiber.Handler func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler // Do performs the given http request and fills the given http response. func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error +// DoRedirects performs the given http request and fills the given http response while following up to maxRedirectsCount redirects. +func DoRedirects(c *fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error +// DoDeadline performs the given request and waits for response until the given deadline. +func DoDeadline(c *fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error +// DoTimeout performs the given request and waits for response during the given timeout duration. +func DoTimeout(c *fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error // DomainForward the given http request based on the given domain and fills the given http response func DomainForward(hostname string, addr string, clients ...*fasthttp.Client) fiber.Handler // BalancerForward performs the given http request based round robin balancer and fills the given http response @@ -73,6 +79,36 @@ app.Get("/:id", func(c *fiber.Ctx) error { return nil }) +// Make proxy requests while following redirects +app.Get("/proxy", func(c *fiber.Ctx) error { + if err := proxy.DoRedirects(c, "http://google.com", 3); err != nil { + return err + } + // Remove Server header from response + c.Response().Header.Del(fiber.HeaderServer) + return nil +}) + +// Make proxy requests and wait up to 5 seconds before timing out +app.Get("/proxy", func(c *fiber.Ctx) error { + if err := proxy.DoTimeout(c, "http://localhost:3000", time.Second * 5); err != nil { + return err + } + // Remove Server header from response + c.Response().Header.Del(fiber.HeaderServer) + return nil +}) + +// Make proxy requests, timeout a minute from now +app.Get("/proxy", func(c *fiber.Ctx) error { + if err := DoDeadline(c, "http://localhost", time.Now().Add(time.Minute)); err != nil { + return err + } + // Remove Server header from response + c.Response().Header.Del(fiber.HeaderServer) + return nil +}) + // Minimal round robin balancer app.Use(proxy.Balancer(proxy.Config{ Servers: []string{ diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 008342631f..3a464d48d2 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -7,6 +7,7 @@ import ( "net/url" "strings" "sync" + "time" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/utils" @@ -139,16 +140,53 @@ func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler { // Do performs the given http request and fills the given http response. // This method can be used within a fiber.Handler func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error { + return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error { + return cli.Do(req, resp) + }, clients...) +} + +// DoRedirects performs the given http request and fills the given http response, following up to maxRedirectsCount redirects. +// When the redirect count exceeds maxRedirectsCount, ErrTooManyRedirects is returned. +// This method can be used within a fiber.Handler +func DoRedirects(c *fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error { + return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error { + return cli.DoRedirects(req, resp, maxRedirectsCount) + }, clients...) +} + +// DoDeadline performs the given request and waits for response until the given deadline. +// This method can be used within a fiber.Handler +func DoDeadline(c *fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error { + return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error { + return cli.DoDeadline(req, resp, deadline) + }, clients...) +} + +// DoTimeout performs the given request and waits for response during the given timeout duration. +// This method can be used within a fiber.Handler +func DoTimeout(c *fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error { + return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error { + return cli.DoTimeout(req, resp, timeout) + }, clients...) +} + +func doAction( + c *fiber.Ctx, + addr string, + action func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error, + clients ...*fasthttp.Client, +) error { var cli *fasthttp.Client + + // set local or global client if len(clients) != 0 { - // Set local client cli = clients[0] } else { - // Set global client lock.RLock() cli = client lock.RUnlock() } + req := c.Request() res := c.Response() originalURL := utils.CopyString(c.OriginalURL()) @@ -157,14 +195,13 @@ func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error { copiedURL := utils.CopyString(addr) req.SetRequestURI(copiedURL) // NOTE: if req.isTLS is true, SetRequestURI keeps the scheme as https. - // issue reference: - // https://github.com/gofiber/fiber/issues/1762 + // Reference: https://github.com/gofiber/fiber/issues/1762 if scheme := getScheme(utils.UnsafeBytes(copiedURL)); len(scheme) > 0 { req.URI().SetSchemeBytes(scheme) } req.Header.Del(fiber.HeaderConnection) - if err := cli.Do(req, res); err != nil { + if err := action(cli, req, res); err != nil { return err } res.Header.Del(fiber.HeaderConnection) diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 6ed169ed69..49be2a2c95 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -2,6 +2,7 @@ package proxy import ( "crypto/tls" + "errors" "io" "net" "net/http/httptest" @@ -48,6 +49,19 @@ func Test_Proxy_Empty_Upstream_Servers(t *testing.T) { app.Use(Balancer(Config{Servers: []string{}})) } +// go test -run Test_Proxy_Empty_Config +func Test_Proxy_Empty_Config(t *testing.T) { + t.Parallel() + + defer func() { + if r := recover(); r != nil { + utils.AssertEqual(t, "Servers cannot be empty", r) + } + }() + app := fiber.New() + app.Use(New(Config{})) +} + // go test -run Test_Proxy_Next func Test_Proxy_Next(t *testing.T) { t.Parallel() @@ -345,24 +359,167 @@ func Test_Proxy_Buffer_Size_Response(t *testing.T) { // go test -race -run Test_Proxy_Do_RestoreOriginalURL func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) { t.Parallel() + _, addr := createProxyTestServer(t, func(c *fiber.Ctx) error { + return c.SendString("proxied") + }) + app := fiber.New() - app.Get("/proxy", func(c *fiber.Ctx) error { - return c.SendString("ok") + app.Get("/test", func(c *fiber.Ctx) error { + return Do(c, "http://"+addr) }) + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) + utils.AssertEqual(t, nil, err1) + utils.AssertEqual(t, "/test", resp.Request.URL.String()) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "proxied", string(body)) +} + +// go test -race -run Test_Proxy_Do_WithRealURL +func Test_Proxy_Do_WithRealURL(t *testing.T) { + t.Parallel() + app := fiber.New() app.Get("/test", func(c *fiber.Ctx) error { - originalURL := utils.CopyString(c.OriginalURL()) - if err := Do(c, "/proxy"); err != nil { - return err - } - utils.AssertEqual(t, originalURL, c.OriginalURL()) - return c.SendString("ok") + return Do(c, "https://www.google.com") + }) + + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) + utils.AssertEqual(t, nil, err1) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) + utils.AssertEqual(t, "/test", resp.Request.URL.String()) + body, err := io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, strings.Contains(string(body), "https://www.google.com/")) +} + +// go test -race -run Test_Proxy_Do_WithRedirect +func Test_Proxy_Do_WithRedirect(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Get("/test", func(c *fiber.Ctx) error { + return Do(c, "https://google.com") + }) + + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) + utils.AssertEqual(t, nil, err1) + body, err := io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, strings.Contains(string(body), "https://www.google.com/")) + utils.AssertEqual(t, 301, resp.StatusCode) +} + +// go test -race -run Test_Proxy_DoRedirects_RestoreOriginalURL +func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Get("/test", func(c *fiber.Ctx) error { + return DoRedirects(c, "http://google.com", 1) + }) + + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) + utils.AssertEqual(t, nil, err1) + _, err := io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) + utils.AssertEqual(t, "/test", resp.Request.URL.String()) +} + +// go test -race -run Test_Proxy_DoRedirects_TooManyRedirects +func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Get("/test", func(c *fiber.Ctx) error { + return DoRedirects(c, "http://google.com", 0) }) + + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) + utils.AssertEqual(t, nil, err1) + body, err := io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "too many redirects detected when doing the request", string(body)) + utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode) + utils.AssertEqual(t, "/test", resp.Request.URL.String()) +} + +// go test -race -run Test_Proxy_DoTimeout_RestoreOriginalURL +func Test_Proxy_DoTimeout_RestoreOriginalURL(t *testing.T) { + t.Parallel() + + _, addr := createProxyTestServer(t, func(c *fiber.Ctx) error { + return c.SendString("proxied") + }) + + app := fiber.New() + app.Get("/test", func(c *fiber.Ctx) error { + return DoTimeout(c, "http://"+addr, time.Second) + }) + + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) + utils.AssertEqual(t, nil, err1) + body, err := io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "proxied", string(body)) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) + utils.AssertEqual(t, "/test", resp.Request.URL.String()) +} + +// go test -race -run Test_Proxy_DoTimeout_Timeout +func Test_Proxy_DoTimeout_Timeout(t *testing.T) { + t.Parallel() + + _, addr := createProxyTestServer(t, func(c *fiber.Ctx) error { + time.Sleep(time.Second * 5) + return c.SendString("proxied") + }) + + app := fiber.New() + app.Get("/test", func(c *fiber.Ctx) error { + return DoTimeout(c, "http://"+addr, time.Second) + }) + _, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) - // This test requires multiple requests due to zero allocation used in fiber - _, err2 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) + utils.AssertEqual(t, errors.New("test: timeout error 1000ms"), err1) +} +// go test -race -run Test_Proxy_DoDeadline_RestoreOriginalURL +func Test_Proxy_DoDeadline_RestoreOriginalURL(t *testing.T) { + t.Parallel() + + _, addr := createProxyTestServer(t, func(c *fiber.Ctx) error { + return c.SendString("proxied") + }) + + app := fiber.New() + app.Get("/test", func(c *fiber.Ctx) error { + return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second)) + }) + + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) utils.AssertEqual(t, nil, err1) - utils.AssertEqual(t, nil, err2) + body, err := io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "proxied", string(body)) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) + utils.AssertEqual(t, "/test", resp.Request.URL.String()) +} + +// go test -race -run Test_Proxy_DoDeadline_PastDeadline +func Test_Proxy_DoDeadline_PastDeadline(t *testing.T) { + t.Parallel() + + _, addr := createProxyTestServer(t, func(c *fiber.Ctx) error { + time.Sleep(time.Second * 5) + return c.SendString("proxied") + }) + + app := fiber.New() + app.Get("/test", func(c *fiber.Ctx) error { + return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second)) + }) + + _, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) + utils.AssertEqual(t, errors.New("test: timeout error 1000ms"), err1) } // go test -race -run Test_Proxy_Do_HTTP_Prefix_URL