diff --git a/echo.go b/echo.go index 9924ac86d..5a1ed6452 100644 --- a/echo.go +++ b/echo.go @@ -206,24 +206,28 @@ const ( // advertised as supported by the target resource. Returning an Allow header is mandatory // for status 405 (method not found) and useful for the OPTIONS method in responses. // See RFC 7231: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 - HeaderAllow = "Allow" - HeaderAuthorization = "Authorization" - HeaderContentDisposition = "Content-Disposition" - HeaderContentEncoding = "Content-Encoding" - HeaderContentLength = "Content-Length" - HeaderContentType = "Content-Type" - HeaderCookie = "Cookie" - HeaderSetCookie = "Set-Cookie" - HeaderIfModifiedSince = "If-Modified-Since" - HeaderLastModified = "Last-Modified" - HeaderLocation = "Location" - HeaderRetryAfter = "Retry-After" - HeaderUpgrade = "Upgrade" - HeaderVary = "Vary" - HeaderWWWAuthenticate = "WWW-Authenticate" - HeaderXForwardedFor = "X-Forwarded-For" - HeaderXForwardedProto = "X-Forwarded-Proto" - HeaderXForwardedProtocol = "X-Forwarded-Protocol" + HeaderAllow = "Allow" + HeaderAuthorization = "Authorization" + HeaderContentDisposition = "Content-Disposition" + HeaderContentEncoding = "Content-Encoding" + HeaderContentLength = "Content-Length" + HeaderContentType = "Content-Type" + HeaderCookie = "Cookie" + HeaderForwarded = "Forwarded" + HeaderSetCookie = "Set-Cookie" + HeaderIfModifiedSince = "If-Modified-Since" + HeaderLastModified = "Last-Modified" + HeaderLocation = "Location" + HeaderRetryAfter = "Retry-After" + HeaderUpgrade = "Upgrade" + HeaderVary = "Vary" + HeaderWWWAuthenticate = "WWW-Authenticate" + HeaderXForwardedFor = "X-Forwarded-For" + HeaderXForwardedHost = "X-Forwarded-Host" + HeaderXForwardedPrefix = "X-Forwarded-Prefix" + HeaderXForwardedProto = "X-Forwarded-Proto" + HeaderXForwardedProtocol = "X-Forwarded-Protocol" + HeaderXForwardedSsl = "X-Forwarded-Ssl" HeaderXUrlScheme = "X-Url-Scheme" HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" diff --git a/middleware/proxy_headers.go b/middleware/proxy_headers.go new file mode 100644 index 000000000..1353171b2 --- /dev/null +++ b/middleware/proxy_headers.go @@ -0,0 +1,58 @@ +package middleware + +import ( + "net/http" + "net/url" + "regexp" + "strings" + + "github.com/labstack/echo/v4" +) + +var ( + protoRegex = regexp.MustCompile(`(?i)(?:proto=)(https|http)`) + ipRegex = regexp.MustCompile("(?i)(?:for=)([^(;|,| )]+)") +) + +func ProxyHeaders() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if fwd := c.Request().Header.Get(echo.HeaderForwarded); fwd != "" { + if match := ipRegex.FindStringSubmatch(fwd); len(match) > 1 { + c.Request().RemoteAddr = strings.Trim(match[1], `"`) + } + } else if fwd := c.RealIP(); fwd != "" { + c.Request().RemoteAddr = fwd + } + + if scheme := getScheme(c.Request()); scheme != "" { + c.Request().URL.Scheme = scheme + } + + if c.Request().Header.Get(echo.HeaderXForwardedHost) != "" { + c.Request().Host = c.Request().Header.Get(echo.HeaderXForwardedHost) + } + + if prefix := c.Request().Header.Get(echo.HeaderXForwardedPrefix); prefix != "" { + c.Request().RequestURI, _ = url.JoinPath(prefix, c.Request().RequestURI) + c.Request().URL.Path, _ = url.JoinPath(prefix, c.Request().URL.Path) + } + return next(c) + } + } +} + +func getScheme(r *http.Request) string { + var scheme string + + if proto := r.Header.Get(echo.HeaderXForwardedProto); proto != "" { + scheme = strings.ToLower(proto) + } else if proto := r.Header.Get(echo.HeaderXForwardedProtocol); proto != "" { + scheme = strings.ToLower(proto) + } else if proto = r.Header.Get(echo.HeaderForwarded); proto != "" { + if match := protoRegex.FindStringSubmatch(proto); len(match) > 1 { + scheme = strings.ToLower(match[1]) + } + } + return scheme +} diff --git a/middleware/proxy_headers_test.go b/middleware/proxy_headers_test.go new file mode 100644 index 000000000..3226a2eda --- /dev/null +++ b/middleware/proxy_headers_test.go @@ -0,0 +1,78 @@ +package middleware + +import ( + "net/http" + "testing" +) + +func Test_getScheme(t *testing.T) { + tests := []struct { + name string + r *http.Request + headerName string + whenHeader string + want string + }{ + { + name: "test only X-Forwarded-Proto: https", + headerName: "X-Forwarded-Proto", + whenHeader: "https", + want: "https", + }, + { + name: "test only X-Forwarded-Proto: http", + headerName: "X-Forwarded-Proto", + whenHeader: "http", + want: "http", + }, + { + name: "test only X-Forwarded-Proto: HTTP", + headerName: "X-Forwarded-Proto", + whenHeader: "HTTP", + want: "http", + }, + { + name: "test only X-Forwarded-Protocol: https", + headerName: "X-Forwarded-Protocol", + whenHeader: "https", + want: "https", + }, + { + name: "test only X-Forwarded-Protocol: http", + headerName: "X-Forwarded-Protocol", + whenHeader: "http", + want: "http", + }, + { + name: "test only X-Forwarded-Protocol: HTTP", + headerName: "X-Forwarded-Protocol", + whenHeader: "HTTP", + want: "http", + }, + { + name: "test only Forwarded https", + headerName: "Forwarded", + whenHeader: "proto=https", + want: "https", + }, + { + name: "test only Forwarded: http", + headerName: "Forwarded", + whenHeader: "proto=http", + want: "http", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &http.Request{ + Header: http.Header{ + tt.headerName: []string{tt.whenHeader}, + }, + } + + if got := getScheme(req); got != tt.want { + t.Errorf("getScheme() = %v, want %v", got, tt.want) + } + }) + } +}