Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to thread matched route key on request context #286

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ func ParamsFromContext(ctx context.Context) Params {
return p
}

type matchKey struct{}

// MatchedRouteKey is the request context key under which the handler path match is stored.
var MatchedRouteKey = matchKey{}

// MatchedRouteFromContext retrieves the matched route path from the context.
func MatchedRouteFromContext(ctx context.Context) string {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if it's even consistent in the existing code, but generally a route should be a tuple (Method, Path) in the context of this router. This is just the matched path.

p, _ := ctx.Value(MatchedRouteKey).(string)
return p
}

// Router is a http.Handler which can be used to dispatch requests to different
// handler functions via configurable routes
type Router struct {
Expand Down Expand Up @@ -186,6 +197,10 @@ type Router struct {
// The handler can be used to keep your server from crashing because of
// unrecovered panics.
PanicHandler func(http.ResponseWriter, *http.Request, interface{})

// AddMatchedRouteToContext when enabled adds the matched router path onto
// the http.Request context before invoking the handler
AddMatchedRouteToContext bool
}

// Make sure the Router conforms with the http.Handler interface
Expand Down Expand Up @@ -354,7 +369,7 @@ func (r *Router) recv(w http.ResponseWriter, req *http.Request) {
// the same path with an extra / without the trailing slash should be performed.
func (r *Router) Lookup(method, path string) (Handle, Params, bool) {
if root := r.trees[method]; root != nil {
handle, ps, tsr := root.getValue(path, r.getParams)
_, handle, ps, tsr := root.getValue(path, r.getParams)
if handle == nil {
r.putParams(ps)
return nil, nil, tsr
Expand Down Expand Up @@ -390,7 +405,7 @@ func (r *Router) allowed(path, reqMethod string) (allow string) {
continue
}

handle, _, _ := r.trees[method].getValue(path, nil)
_, handle, _, _ := r.trees[method].getValue(path, nil)
if handle != nil {
// Add request method to list of allowed methods
allowed = append(allowed, method)
Expand Down Expand Up @@ -426,7 +441,11 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path

if root := r.trees[req.Method]; root != nil {
if handle, ps, tsr := root.getValue(path, r.getParams); handle != nil {
if match, handle, ps, tsr := root.getValue(path, r.getParams); handle != nil {
if r.AddMatchedRouteToContext {
req = req.WithContext(context.WithValue(req.Context(), MatchedRouteKey, match))
}

if ps != nil {
handle(w, req, *ps)
r.putParams(ps)
Expand Down
26 changes: 26 additions & 0 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,32 @@ func TestRouterParamsFromContext(t *testing.T) {
}
}

func TestRouterMatchedRouteFromContext(t *testing.T) {
routed := false

handlerFunc := func(_ http.ResponseWriter, req *http.Request) {
// get params from request context
route := MatchedRouteFromContext(req.Context())

if route != "/user/:name" {
t.Fatalf("Wrong matched route: want /user/:name, got %q", route)
}

routed = true
}

router := New()
router.AddMatchedRouteToContext = true
router.HandlerFunc(http.MethodGet, "/user/:name", handlerFunc)

w := new(mockResponseWriter)
r, _ := http.NewRequest(http.MethodGet, "/user/gopher", nil)
router.ServeHTTP(w, r)
if !routed {
t.Fatal("Routing failed!")
}
}

type mockFileSystem struct {
opened bool
}
Expand Down
13 changes: 12 additions & 1 deletion tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ type node struct {
priority uint32
children []*node
handle Handle
fullPath string
}

// Increments priority of the given child and reorders if necessary
Expand Down Expand Up @@ -134,6 +135,7 @@ walk:
indices: n.indices,
children: n.children,
handle: n.handle,
fullPath: n.fullPath,
priority: n.priority - 1,
}

Expand Down Expand Up @@ -210,7 +212,10 @@ walk:
if n.handle != nil {
panic("a handle is already registered for path '" + fullPath + "'")
}

n.handle = handle
n.fullPath = fullPath

return
}
}
Expand Down Expand Up @@ -271,6 +276,7 @@ func (n *node) insertChild(path, fullPath string, handle Handle) {

// Otherwise we're done. Insert the handle in the new leaf
n.handle = handle
n.fullPath = fullPath
return

} else { // catchAll
Expand Down Expand Up @@ -305,6 +311,7 @@ func (n *node) insertChild(path, fullPath string, handle Handle) {
path: path[i:],
nType: catchAll,
handle: handle,
fullPath: fullPath,
priority: 1,
}
n.children = []*node{child}
Expand All @@ -316,14 +323,15 @@ func (n *node) insertChild(path, fullPath string, handle Handle) {
// If no wildcard was found, simple insert the path and handle
n.path = path
n.handle = handle
n.fullPath = fullPath
}

// Returns the handle registered with the given path (key). The values of
// wildcards are saved to a map.
// If no handle can be found, a TSR (trailing slash redirect) recommendation is
// made if a handle exists with an extra (without the) trailing slash for the
// given path.
func (n *node) getValue(path string, params func() *Params) (handle Handle, ps *Params, tsr bool) {
func (n *node) getValue(path string, params func() *Params) (matchPath string, handle Handle, ps *Params, tsr bool) {
walk: // Outer loop for walking the tree
for {
prefix := n.path
Expand Down Expand Up @@ -391,6 +399,7 @@ walk: // Outer loop for walking the tree
}

if handle = n.handle; handle != nil {
matchPath = n.fullPath
return
} else if len(n.children) == 1 {
// No handle found. Check if a handle for this path + a
Expand All @@ -416,6 +425,7 @@ walk: // Outer loop for walking the tree
}
}

matchPath = n.fullPath
handle = n.handle
return

Expand All @@ -427,6 +437,7 @@ walk: // Outer loop for walking the tree
// We should have reached the node containing the handle.
// Check if this node has a handle registered.
if handle = n.handle; handle != nil {
matchPath = n.fullPath
return
}

Expand Down
12 changes: 8 additions & 4 deletions tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func getParams() *Params {

func checkRequests(t *testing.T, tree *node, requests testRequests) {
for _, request := range requests {
handler, psp, _ := tree.getValue(request.path, getParams)
route, handler, psp, _ := tree.getValue(request.path, getParams)

if handler == nil {
if !request.nilHandler {
Expand All @@ -59,6 +59,10 @@ func checkRequests(t *testing.T, tree *node, requests testRequests) {
if fakeHandlerValue != request.route {
t.Errorf("handle mismatch for route '%s': Wrong handle (%s != %s)", request.path, fakeHandlerValue, request.route)
}

if route != request.route {
t.Errorf("route mismatch for path '%s': Wrong route (%s != %s)", request.path, route, request.route)
}
}

var ps Params
Expand Down Expand Up @@ -427,7 +431,7 @@ func TestTreeTrailingSlashRedirect(t *testing.T) {
"/doc/",
}
for _, route := range tsrRoutes {
handler, _, tsr := tree.getValue(route, nil)
_, handler, _, tsr := tree.getValue(route, nil)
if handler != nil {
t.Fatalf("non-nil handler for TSR route '%s", route)
} else if !tsr {
Expand All @@ -444,7 +448,7 @@ func TestTreeTrailingSlashRedirect(t *testing.T) {
"/api/world/abc",
}
for _, route := range noTsrRoutes {
handler, _, tsr := tree.getValue(route, nil)
_, handler, _, tsr := tree.getValue(route, nil)
if handler != nil {
t.Fatalf("non-nil handler for No-TSR route '%s", route)
} else if tsr {
Expand All @@ -463,7 +467,7 @@ func TestTreeRootTrailingSlashRedirect(t *testing.T) {
t.Fatalf("panic inserting test route: %v", recv)
}

handler, _, tsr := tree.getValue("/", nil)
_, handler, _, tsr := tree.getValue("/", nil)
if handler != nil {
t.Fatalf("non-nil handler")
} else if tsr {
Expand Down