diff --git a/router.go b/router.go index e0db1dc..b8063e4 100644 --- a/router.go +++ b/router.go @@ -109,6 +109,7 @@ func redirect(w http.ResponseWriter, r *http.Request, newPath string, statusCode func (t *TreeMux) lookup(w http.ResponseWriter, r *http.Request) (result LookupResult, found bool) { result.StatusCode = http.StatusNotFound path := r.RequestURI + unescapedPath := r.URL.Path pathLen := len(path) if pathLen > 0 && t.PathSource == RequestURI { rawQueryLen := len(r.URL.RawQuery) @@ -128,6 +129,7 @@ func (t *TreeMux) lookup(w http.ResponseWriter, r *http.Request) (result LookupR trailingSlash := path[pathLen-1] == '/' && pathLen > 1 if trailingSlash && t.RedirectTrailingSlash { path = path[:pathLen-1] + unescapedPath = unescapedPath[:len(unescapedPath)-1] } n, handler, params := t.root.search(r.Method, path[1:]) @@ -135,7 +137,7 @@ func (t *TreeMux) lookup(w http.ResponseWriter, r *http.Request) (result LookupR if t.RedirectCleanPath { // Path was not found. Try cleaning it up and search again. // TODO Test this - cleanPath := Clean(path) + cleanPath := Clean(unescapedPath) n, handler, params = t.root.search(r.Method, cleanPath[1:]) if n == nil { // Still nothing found. @@ -169,11 +171,11 @@ func (t *TreeMux) lookup(w http.ResponseWriter, r *http.Request) (result LookupR var h HandlerFunc if n.addSlash { // Need to add a slash. - h = redirectHandler(path+"/", statusCode) + h = redirectHandler(unescapedPath+"/", statusCode) } else if path != "/" { // We need to remove the slash. This was already done at the // beginning of the function. - h = redirectHandler(path, statusCode) + h = redirectHandler(unescapedPath, statusCode) } if h != nil { diff --git a/router_test.go b/router_test.go index 1987693..e935407 100644 --- a/router_test.go +++ b/router_test.go @@ -22,7 +22,7 @@ func panicHandler(w http.ResponseWriter, r *http.Request, params map[string]stri func newRequest(method, path string, body io.Reader) (*http.Request, error) { r, _ := http.NewRequest(method, path, body) - u, _ := url.Parse(path) + u, _ := url.ParseRequestURI(path) r.URL = u r.RequestURI = path return r, nil @@ -428,7 +428,7 @@ func testRedirect(t *testing.T, defaultBehavior, getBehavior, postBehavior Redir t.Errorf("/noslash/ expected code %d, saw %d", expectedCode, w.Code) } if expectedCode != http.StatusNoContent && w.Header().Get("Location") != "/noslash" { - t.Errorf("/noslash/ was not redirected to /noslash") + t.Errorf("/noslash/ was redirected to `%s` instead of /noslash", w.Header().Get("Location")) } r, _ = newRequest(method, "//noslash/", nil) @@ -1043,6 +1043,36 @@ func TestLookup(t *testing.T) { tryLookup("POST", "/user/dimfeld/", true, http.StatusTemporaryRedirect) } +func TestRedirectEscapedPath(t *testing.T) { + router := New() + + testHandler := func(w http.ResponseWriter, r *http.Request, params map[string]string) {} + + router.GET("/:escaped/", testHandler) + + w := httptest.NewRecorder() + u, err := url.Parse("/Test P@th") + if err != nil { + t.Error(err) + return + } + + r, _ := newRequest("GET", u.String(), nil) + + router.ServeHTTP(w, r) + + if w.Code != http.StatusMovedPermanently { + t.Errorf("Expected status 301 but saw %d", w.Code) + } + + path := w.Header().Get("Location") + expected := "/Test%20P@th/" + if path != expected { + t.Errorf("Given path wasn't escaped correctly.\n"+ + "Expected: %q\nBut got: %q", expected, path) + } +} + func BenchmarkRouterSimple(b *testing.B) { router := New()