From 34e411acc5017cc0f5b1a447ed46757b02bb06a8 Mon Sep 17 00:00:00 2001 From: George MacRorie Date: Thu, 21 Nov 2019 14:46:54 +0100 Subject: [PATCH 1/2] Add option to thread matched route key on request context --- router.go | 25 ++++++++++++++++++++++--- router_test.go | 26 ++++++++++++++++++++++++++ tree.go | 18 +++++++++++++++++- tree_test.go | 12 ++++++++---- 4 files changed, 73 insertions(+), 8 deletions(-) diff --git a/router.go b/router.go index 088ceec0..689fcaf5 100644 --- a/router.go +++ b/router.go @@ -122,6 +122,17 @@ func ParamsFromContext(ctx context.Context) Params { return p } +type matchKey struct{} + +// MatchedRouteKey is the request context key under which the handler path match is stored. +var MatchedRouteKey = matchKey{} + +// MatchedRouteFromContext retrieves the matched route path from the context. +func MatchedRouteFromContext(ctx context.Context) string { + p, _ := ctx.Value(MatchedRouteKey).(string) + return p +} + // Router is a http.Handler which can be used to dispatch requests to different // handler functions via configurable routes type Router struct { @@ -186,6 +197,10 @@ type Router struct { // The handler can be used to keep your server from crashing because of // unrecovered panics. PanicHandler func(http.ResponseWriter, *http.Request, interface{}) + + // AddMatchedRouteToContext when enabled adds the matched router path onto + // the http.Request context before invoking the handler + AddMatchedRouteToContext bool } // Make sure the Router conforms with the http.Handler interface @@ -354,7 +369,7 @@ func (r *Router) recv(w http.ResponseWriter, req *http.Request) { // the same path with an extra / without the trailing slash should be performed. func (r *Router) Lookup(method, path string) (Handle, Params, bool) { if root := r.trees[method]; root != nil { - handle, ps, tsr := root.getValue(path, r.getParams) + _, handle, ps, tsr := root.getValue(path, r.getParams) if handle == nil { r.putParams(ps) return nil, nil, tsr @@ -390,7 +405,7 @@ func (r *Router) allowed(path, reqMethod string) (allow string) { continue } - handle, _, _ := r.trees[method].getValue(path, nil) + _, handle, _, _ := r.trees[method].getValue(path, nil) if handle != nil { // Add request method to list of allowed methods allowed = append(allowed, method) @@ -426,7 +441,11 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { path := req.URL.Path if root := r.trees[req.Method]; root != nil { - if handle, ps, tsr := root.getValue(path, r.getParams); handle != nil { + if match, handle, ps, tsr := root.getValue(path, r.getParams); handle != nil { + if r.AddMatchedRouteToContext { + req = req.WithContext(context.WithValue(req.Context(), MatchedRouteKey, match)) + } + if ps != nil { handle(w, req, *ps) r.putParams(ps) diff --git a/router_test.go b/router_test.go index 2aa18598..d01fa0ac 100644 --- a/router_test.go +++ b/router_test.go @@ -610,6 +610,32 @@ func TestRouterParamsFromContext(t *testing.T) { } } +func TestRouterMatchedRouteFromContext(t *testing.T) { + routed := false + + handlerFunc := func(_ http.ResponseWriter, req *http.Request) { + // get params from request context + route := MatchedRouteFromContext(req.Context()) + + if route != "/user/:name" { + t.Fatalf("Wrong matched route: want /user/:name, got %q", route) + } + + routed = true + } + + router := New() + router.AddMatchedRouteToContext = true + router.HandlerFunc(http.MethodGet, "/user/:name", handlerFunc) + + w := new(mockResponseWriter) + r, _ := http.NewRequest(http.MethodGet, "/user/gopher", nil) + router.ServeHTTP(w, r) + if !routed { + t.Fatal("Routing failed!") + } +} + type mockFileSystem struct { opened bool } diff --git a/tree.go b/tree.go index 94c35f43..26867534 100644 --- a/tree.go +++ b/tree.go @@ -71,6 +71,11 @@ const ( catchAll ) +type handleWithFullPath struct { + handle Handle + fullPath string +} + type node struct { path string indices string @@ -79,6 +84,7 @@ type node struct { priority uint32 children []*node handle Handle + fullPath string } // Increments priority of the given child and reorders if necessary @@ -134,6 +140,7 @@ walk: indices: n.indices, children: n.children, handle: n.handle, + fullPath: n.fullPath, priority: n.priority - 1, } @@ -210,7 +217,10 @@ walk: if n.handle != nil { panic("a handle is already registered for path '" + fullPath + "'") } + n.handle = handle + n.fullPath = fullPath + return } } @@ -271,6 +281,7 @@ func (n *node) insertChild(path, fullPath string, handle Handle) { // Otherwise we're done. Insert the handle in the new leaf n.handle = handle + n.fullPath = fullPath return } else { // catchAll @@ -305,6 +316,7 @@ func (n *node) insertChild(path, fullPath string, handle Handle) { path: path[i:], nType: catchAll, handle: handle, + fullPath: fullPath, priority: 1, } n.children = []*node{child} @@ -316,6 +328,7 @@ func (n *node) insertChild(path, fullPath string, handle Handle) { // If no wildcard was found, simple insert the path and handle n.path = path n.handle = handle + n.fullPath = fullPath } // Returns the handle registered with the given path (key). The values of @@ -323,7 +336,7 @@ func (n *node) insertChild(path, fullPath string, handle Handle) { // If no handle can be found, a TSR (trailing slash redirect) recommendation is // made if a handle exists with an extra (without the) trailing slash for the // given path. -func (n *node) getValue(path string, params func() *Params) (handle Handle, ps *Params, tsr bool) { +func (n *node) getValue(path string, params func() *Params) (matchPath string, handle Handle, ps *Params, tsr bool) { walk: // Outer loop for walking the tree for { prefix := n.path @@ -391,6 +404,7 @@ walk: // Outer loop for walking the tree } if handle = n.handle; handle != nil { + matchPath = n.fullPath return } else if len(n.children) == 1 { // No handle found. Check if a handle for this path + a @@ -416,6 +430,7 @@ walk: // Outer loop for walking the tree } } + matchPath = n.fullPath handle = n.handle return @@ -427,6 +442,7 @@ walk: // Outer loop for walking the tree // We should have reached the node containing the handle. // Check if this node has a handle registered. if handle = n.handle; handle != nil { + matchPath = n.fullPath return } diff --git a/tree_test.go b/tree_test.go index c604b742..518790d8 100644 --- a/tree_test.go +++ b/tree_test.go @@ -46,7 +46,7 @@ func getParams() *Params { func checkRequests(t *testing.T, tree *node, requests testRequests) { for _, request := range requests { - handler, psp, _ := tree.getValue(request.path, getParams) + route, handler, psp, _ := tree.getValue(request.path, getParams) if handler == nil { if !request.nilHandler { @@ -59,6 +59,10 @@ func checkRequests(t *testing.T, tree *node, requests testRequests) { if fakeHandlerValue != request.route { t.Errorf("handle mismatch for route '%s': Wrong handle (%s != %s)", request.path, fakeHandlerValue, request.route) } + + if route != request.route { + t.Errorf("route mismatch for path '%s': Wrong route (%s != %s)", request.path, route, request.route) + } } var ps Params @@ -427,7 +431,7 @@ func TestTreeTrailingSlashRedirect(t *testing.T) { "/doc/", } for _, route := range tsrRoutes { - handler, _, tsr := tree.getValue(route, nil) + _, handler, _, tsr := tree.getValue(route, nil) if handler != nil { t.Fatalf("non-nil handler for TSR route '%s", route) } else if !tsr { @@ -444,7 +448,7 @@ func TestTreeTrailingSlashRedirect(t *testing.T) { "/api/world/abc", } for _, route := range noTsrRoutes { - handler, _, tsr := tree.getValue(route, nil) + _, handler, _, tsr := tree.getValue(route, nil) if handler != nil { t.Fatalf("non-nil handler for No-TSR route '%s", route) } else if tsr { @@ -463,7 +467,7 @@ func TestTreeRootTrailingSlashRedirect(t *testing.T) { t.Fatalf("panic inserting test route: %v", recv) } - handler, _, tsr := tree.getValue("/", nil) + _, handler, _, tsr := tree.getValue("/", nil) if handler != nil { t.Fatalf("non-nil handler") } else if tsr { From 148b36ae2a0356d0563e597f3ba77560c3c0275d Mon Sep 17 00:00:00 2001 From: George MacRorie Date: Thu, 21 Nov 2019 17:21:37 +0100 Subject: [PATCH 2/2] Remove unused type added by mistake --- tree.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tree.go b/tree.go index 26867534..b177afe3 100644 --- a/tree.go +++ b/tree.go @@ -71,11 +71,6 @@ const ( catchAll ) -type handleWithFullPath struct { - handle Handle - fullPath string -} - type node struct { path string indices string