diff --git a/README.md b/README.md index d9b7ea5..0d82f1d 100644 --- a/README.md +++ b/README.md @@ -36,9 +36,15 @@ group.GET("/v1/:id", func(w http.ResponseWriter, r *http.Request, params map[str // UsingContext returns a version of the router or group with context support. ctxGroup := group.UsingContext() // sibling to 'group' node in tree ctxGroup.GET("/v2/:id", func(w http.ResponseWriter, r *http.Request) { - params := httptreemux.ContextParams(r.Context()) + ctxData := httptreemux.ContextData(r.Context()) + params := ctxData.Params() id := params["id"] - fmt.Fprintf(w, "GET /api/v2/%s", id) + + // Useful for middleware to see which route was hit without dealing with wildcards + routePath := ctxData.Route() + + // Prints GET /api/v2/:id id=... + fmt.Fprintf(w, "GET %s id=%s", routePath, id) }) http.ListenAndServe(":8080", router) @@ -58,9 +64,15 @@ router.GET("/:page", func(w http.ResponseWriter, r *http.Request) { group := router.NewGroup("/api") group.GET("/v1/:id", func(w http.ResponseWriter, r *http.Request) { - params := httptreemux.ContextParams(r.Context()) + ctxData := httptreemux.ContextData(r.Context()) + params := ctxData.Params() id := params["id"] - fmt.Fprintf(w, "GET /api/v1/%s", id) + + // Useful for middleware to see which route was hit without dealing with wildcards + routePath := ctxData.Route() + + // Prints GET /api/v1/:id id=... + fmt.Fprintf(w, "GET %s id=%s", routePath, id) }) http.ListenAndServe(":8080", router) diff --git a/context.go b/context.go index 09b537b..18805f4 100644 --- a/context.go +++ b/context.go @@ -58,10 +58,13 @@ func (cg *ContextGroup) NewGroup(path string) *ContextGroup { // Handle allows handling HTTP requests via an http.HandlerFunc, as opposed to an httptreemux.HandlerFunc. // Any parameters from the request URL are stored in a map[string]string in the request's context. func (cg *ContextGroup) Handle(method, path string, handler http.HandlerFunc) { + fullPath := cg.group.path + path cg.group.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) { - if params != nil { - r = r.WithContext(AddParamsToContext(r.Context(), params)) + routeData := contextData{ + route: fullPath, + params: params, } + r = r.WithContext(AddRouteDataToContext(r.Context(), routeData)) handler(w, r) }) } @@ -69,10 +72,13 @@ func (cg *ContextGroup) Handle(method, path string, handler http.HandlerFunc) { // Handler allows handling HTTP requests via an http.Handler interface, as opposed to an httptreemux.HandlerFunc. // Any parameters from the request URL are stored in a map[string]string in the request's context. func (cg *ContextGroup) Handler(method, path string, handler http.Handler) { + fullPath := cg.group.path + path cg.group.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) { - if params != nil { - r = r.WithContext(AddParamsToContext(r.Context(), params)) + routeData := contextData{ + route: fullPath, + params: params, } + r = r.WithContext(AddRouteDataToContext(r.Context(), routeData)) handler.ServeHTTP(w, r) }) } @@ -112,22 +118,61 @@ func (cg *ContextGroup) OPTIONS(path string, handler http.HandlerFunc) { cg.Handle("OPTIONS", path, handler) } +type contextData struct { + route string + params map[string]string +} + +func (cd contextData) Route() string { + return cd.route +} + +func (cd contextData) Params() map[string]string { + if cd.params != nil { + return cd.params + } + return map[string]string{} +} + +// ContextData is the information associated with +type ContextRouteData interface { + Route() string + Params() map[string]string +} + // ContextParams returns the params map associated with the given context if one exists. Otherwise, an empty map is returned. func ContextParams(ctx context.Context) map[string]string { - if p, ok := ctx.Value(paramsContextKey).(map[string]string); ok { - return p + if p, ok := ctx.Value(routeContextKey).(ContextRouteData); ok { + return p.Params() } return map[string]string{} } +// ContextData returns the full route path associated with the given context, without wildcard expansion. +func ContextData(ctx context.Context) ContextRouteData { + if p, ok := ctx.Value(routeContextKey).(ContextRouteData); ok { + return p + } + return nil +} + +func AddRouteDataToContext(ctx context.Context, data ContextRouteData) context.Context { + return context.WithValue(ctx, routeContextKey, data) +} + // AddParamsToContext inserts a parameters map into a context using -// the package's internal context key. Clients of this package should -// really only use this for unit tests. +// the package's internal context key. This function is deprecated. +// Use AddRouteDataToContext instead. func AddParamsToContext(ctx context.Context, params map[string]string) context.Context { - return context.WithValue(ctx, paramsContextKey, params) + data := contextData{ + route: "", + params: params, + } + return AddRouteDataToContext(ctx, data) } type contextKey int -// paramsContextKey is used to retrieve a path's params map from a request's context. -const paramsContextKey contextKey = 0 +// paramsContextKey and routeContextKey are used to retrieve a path's params map and route +// from a request's context. +const routeContextKey contextKey = 0 diff --git a/context_test.go b/context_test.go index 33335b1..48d0eb7 100644 --- a/context_test.go +++ b/context_test.go @@ -4,6 +4,7 @@ package httptreemux import ( "context" + "fmt" "net/http" "net/http/httptest" "testing" @@ -23,8 +24,12 @@ type IContextGroup interface { } func TestContextParams(t *testing.T) { - m := map[string]string{"id": "123"} - ctx := context.WithValue(context.Background(), paramsContextKey, m) + m := contextData{ + params: map[string]string{"id": "123"}, + route: "", + } + + ctx := context.WithValue(context.Background(), routeContextKey, m) params := ContextParams(ctx) if params == nil { @@ -32,96 +37,142 @@ func TestContextParams(t *testing.T) { } if v := params["id"]; v != "123" { - t.Errorf("expected '%s', but got '%#v'", m["id"], params["id"]) + t.Errorf("expected '%s', but got '%#v'", m.params["id"], params["id"]) + } +} + +func TestContextData(t *testing.T) { + p := contextData{ + route: "route/path", + params: map[string]string{"id": "123"}, + } + + ctx := context.WithValue(context.Background(), routeContextKey, p) + + ctxData := ContextData(ctx) + pathValue := ctxData.Route() + if pathValue != p.route { + t.Errorf("expected '%s', but got '%s'", p, pathValue) + } + + params := ctxData.Params() + if v := params["id"]; v != "123" { + t.Errorf("expected '%s', but got '%#v'", p.params["id"], params["id"]) + } +} + +func TestContextDataWithEmptyParams(t *testing.T) { + p := contextData{ + route: "route/path", + params: nil, + } + + ctx := context.WithValue(context.Background(), routeContextKey, p) + params := ContextData(ctx).Params() + if params == nil { + t.Errorf("ContextData.Params should never return nil") } } func TestContextGroupMethods(t *testing.T) { for _, scenario := range scenarios { - t.Log(scenario.description) - testContextGroupMethods(t, scenario.RequestCreator, true, false) - testContextGroupMethods(t, scenario.RequestCreator, false, false) - testContextGroupMethods(t, scenario.RequestCreator, true, true) - testContextGroupMethods(t, scenario.RequestCreator, false, true) + t.Run(scenario.description, func(t *testing.T) { + testContextGroupMethods(t, scenario.RequestCreator, true, false) + testContextGroupMethods(t, scenario.RequestCreator, false, false) + testContextGroupMethods(t, scenario.RequestCreator, true, true) + testContextGroupMethods(t, scenario.RequestCreator, false, true) + }) } } func testContextGroupMethods(t *testing.T, reqGen RequestCreator, headCanUseGet bool, useContextRouter bool) { - t.Logf("Running test: headCanUseGet %v, useContextRouter %v", headCanUseGet, useContextRouter) + t.Run(fmt.Sprintf("headCanUseGet %v, useContextRouter %v", headCanUseGet, useContextRouter), func(t *testing.T) { + var result string + makeHandler := func(method, expectedRoutePath string, hasParam bool) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + result = method + + // Test Legacy Accessor + var v string + v, ok := ContextParams(r.Context())["param"] + if hasParam && !ok { + t.Error("missing key 'param' in context from ContextParams") + } + + ctxData := ContextData(r.Context()) + v, ok = ctxData.Params()["param"] + if hasParam && !ok { + t.Error("missing key 'param' in context from ContextData") + } + + routePath := ctxData.Route() + if routePath != expectedRoutePath { + t.Errorf("Expected context to have route path '%s', saw %s", expectedRoutePath, routePath) + } + + if headCanUseGet && (method == "GET" || v == "HEAD") { + return + } + if hasParam && v != method { + t.Errorf("invalid key 'param' in context; expected '%s' but got '%s'", method, v) + } + } + } - var result string - makeHandler := func(method string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - result = method + var router http.Handler + var rootGroup IContextGroup - v, ok := ContextParams(r.Context())["param"] - if !ok { - t.Error("missing key 'param' in context") - } + if useContextRouter { + root := NewContextMux() + root.HeadCanUseGet = headCanUseGet + t.Log(root.TreeMux.HeadCanUseGet) + router = root + rootGroup = root + } else { + root := New() + root.HeadCanUseGet = headCanUseGet + router = root + rootGroup = root.UsingContext() + } - if headCanUseGet && (method == "GET" || v == "HEAD") { - return + cg := rootGroup.NewGroup("/base").NewGroup("/user") + cg.GET("/:param", makeHandler("GET", cg.group.path+"/:param", true)) + cg.POST("/:param", makeHandler("POST", cg.group.path+"/:param", true)) + cg.PATCH("/PATCH", makeHandler("PATCH", cg.group.path+"/PATCH", false)) + cg.PUT("/:param", makeHandler("PUT", cg.group.path+"/:param", true)) + cg.Handler("DELETE", "/:param", http.HandlerFunc(makeHandler("DELETE", cg.group.path+"/:param", true))) + + testMethod := func(method, expect string) { + result = "" + w := httptest.NewRecorder() + r, _ := reqGen(method, "/base/user/"+method, nil) + router.ServeHTTP(w, r) + if expect == "" && w.Code != http.StatusMethodNotAllowed { + t.Errorf("Method %s not expected to match but saw code %d", method, w.Code) } - if v != method { - t.Errorf("invalid key 'param' in context; expected '%s' but got '%s'", method, v) + if result != expect { + t.Errorf("Method %s got result %s", method, result) } } - } - - var router http.Handler - var rootGroup IContextGroup - - if useContextRouter { - root := NewContextMux() - root.HeadCanUseGet = headCanUseGet - t.Log(root.TreeMux.HeadCanUseGet) - router = root - rootGroup = root - } else { - root := New() - root.HeadCanUseGet = headCanUseGet - router = root - rootGroup = root.UsingContext() - } - cg := rootGroup.NewGroup("/base").NewGroup("/user") - cg.GET("/:param", makeHandler("GET")) - cg.POST("/:param", makeHandler("POST")) - cg.PATCH("/:param", makeHandler("PATCH")) - cg.PUT("/:param", makeHandler("PUT")) - cg.DELETE("/:param", makeHandler("DELETE")) + testMethod("GET", "GET") + testMethod("POST", "POST") + testMethod("PATCH", "PATCH") + testMethod("PUT", "PUT") + testMethod("DELETE", "DELETE") - testMethod := func(method, expect string) { - result = "" - w := httptest.NewRecorder() - r, _ := reqGen(method, "/base/user/"+method, nil) - router.ServeHTTP(w, r) - if expect == "" && w.Code != http.StatusMethodNotAllowed { - t.Errorf("Method %s not expected to match but saw code %d", method, w.Code) - } - - if result != expect { - t.Errorf("Method %s got result %s", method, result) + if headCanUseGet { + t.Log("Test implicit HEAD with HeadCanUseGet = true") + testMethod("HEAD", "GET") + } else { + t.Log("Test implicit HEAD with HeadCanUseGet = false") + testMethod("HEAD", "") } - } - - testMethod("GET", "GET") - testMethod("POST", "POST") - testMethod("PATCH", "PATCH") - testMethod("PUT", "PUT") - testMethod("DELETE", "DELETE") - - if headCanUseGet { - t.Log("Test implicit HEAD with HeadCanUseGet = true") - testMethod("HEAD", "GET") - } else { - t.Log("Test implicit HEAD with HeadCanUseGet = false") - testMethod("HEAD", "") - } - cg.HEAD("/:param", makeHandler("HEAD")) - testMethod("HEAD", "HEAD") + cg.HEAD("/:param", makeHandler("HEAD", cg.group.path+"/:param", true)) + testMethod("HEAD", "HEAD") + }) } func TestNewContextGroup(t *testing.T) {