Skip to content

Commit

Permalink
Add option to thread matched route key on request context
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeMac committed Nov 21, 2019
1 parent 08a3b3d commit 2b8357f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 16 deletions.
25 changes: 22 additions & 3 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ func ParamsFromContext(ctx context.Context) Params {
return p
}

type matchKey struct{}

// MatchedRouteKey is the request context key under which the handler path match is stored.
var MatchedRouteKey = matchKey{}

// MatchedRouteFromContext retrieves the matched route path from the context.
func MatchedRouteFromContext(ctx context.Context) string {
p, _ := ctx.Value(MatchedRouteKey).(string)
return p
}

// Router is a http.Handler which can be used to dispatch requests to different
// handler functions via configurable routes
type Router struct {
Expand Down Expand Up @@ -186,6 +197,10 @@ type Router struct {
// The handler can be used to keep your server from crashing because of
// unrecovered panics.
PanicHandler func(http.ResponseWriter, *http.Request, interface{})

// AddMatchedRouteToContext when enabled adds the matched router path onto
// the http.Request context before invoking the handler
AddMatchedRouteToContext bool
}

// Make sure the Router conforms with the http.Handler interface
Expand Down Expand Up @@ -354,7 +369,7 @@ func (r *Router) recv(w http.ResponseWriter, req *http.Request) {
// the same path with an extra / without the trailing slash should be performed.
func (r *Router) Lookup(method, path string) (Handle, Params, bool) {
if root := r.trees[method]; root != nil {
handle, ps, tsr := root.getValue(path, r.getParams)
_, handle, ps, tsr := root.getValue(path, r.getParams)
if handle == nil {
r.putParams(ps)
return nil, nil, tsr
Expand Down Expand Up @@ -390,7 +405,7 @@ func (r *Router) allowed(path, reqMethod string) (allow string) {
continue
}

handle, _, _ := r.trees[method].getValue(path, nil)
_, handle, _, _ := r.trees[method].getValue(path, nil)
if handle != nil {
// Add request method to list of allowed methods
allowed = append(allowed, method)
Expand Down Expand Up @@ -426,7 +441,11 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path

if root := r.trees[req.Method]; root != nil {
if handle, ps, tsr := root.getValue(path, r.getParams); handle != nil {
if match, handle, ps, tsr := root.getValue(path, r.getParams); handle != nil {
if r.AddMatchedRouteToContext {
req = req.WithContext(context.WithValue(req.Context(), MatchedRouteKey, match))
}

if ps != nil {
handle(w, req, *ps)
r.putParams(ps)
Expand Down
34 changes: 25 additions & 9 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,19 @@ const (
catchAll
)

type handleWithFullPath struct {
handle Handle
fullPath string
}

type node struct {
path string
indices string
wildChild bool
nType nodeType
priority uint32
children []*node
handle Handle
handle *handleWithFullPath
}

// Increments priority of the given child and reorders if necessary
Expand Down Expand Up @@ -210,7 +215,9 @@ walk:
if n.handle != nil {
panic("a handle is already registered for path '" + fullPath + "'")
}
n.handle = handle
if handle != nil {
n.handle = &handleWithFullPath{handle: handle, fullPath: fullPath}
}
return
}
}
Expand Down Expand Up @@ -270,7 +277,9 @@ func (n *node) insertChild(path, fullPath string, handle Handle) {
}

// Otherwise we're done. Insert the handle in the new leaf
n.handle = handle
if handle != nil {
n.handle = &handleWithFullPath{handle: handle, fullPath: fullPath}
}
return

} else { // catchAll
Expand Down Expand Up @@ -304,7 +313,7 @@ func (n *node) insertChild(path, fullPath string, handle Handle) {
child = &node{
path: path[i:],
nType: catchAll,
handle: handle,
handle: &handleWithFullPath{handle: handle, fullPath: fullPath},
priority: 1,
}
n.children = []*node{child}
Expand All @@ -315,15 +324,17 @@ func (n *node) insertChild(path, fullPath string, handle Handle) {

// If no wildcard was found, simple insert the path and handle
n.path = path
n.handle = handle
if handle != nil {
n.handle = &handleWithFullPath{handle: handle, fullPath: fullPath}
}
}

// Returns the handle registered with the given path (key). The values of
// wildcards are saved to a map.
// If no handle can be found, a TSR (trailing slash redirect) recommendation is
// made if a handle exists with an extra (without the) trailing slash for the
// given path.
func (n *node) getValue(path string, params func() *Params) (handle Handle, ps *Params, tsr bool) {
func (n *node) getValue(path string, params func() *Params) (matchPath string, handle Handle, ps *Params, tsr bool) {
walk: // Outer loop for walking the tree
for {
prefix := n.path
Expand Down Expand Up @@ -390,7 +401,9 @@ walk: // Outer loop for walking the tree
return
}

if handle = n.handle; handle != nil {
if n.handle != nil {
matchPath = n.handle.fullPath
handle = n.handle.handle
return
} else if len(n.children) == 1 {
// No handle found. Check if a handle for this path + a
Expand All @@ -416,7 +429,8 @@ walk: // Outer loop for walking the tree
}
}

handle = n.handle
matchPath = n.handle.fullPath
handle = n.handle.handle
return

default:
Expand All @@ -426,7 +440,9 @@ walk: // Outer loop for walking the tree
} else if path == prefix {
// We should have reached the node containing the handle.
// Check if this node has a handle registered.
if handle = n.handle; handle != nil {
if n.handle != nil {
matchPath = n.handle.fullPath
handle = n.handle.handle
return
}

Expand Down
12 changes: 8 additions & 4 deletions tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func getParams() *Params {

func checkRequests(t *testing.T, tree *node, requests testRequests) {
for _, request := range requests {
handler, psp, _ := tree.getValue(request.path, getParams)
route, handler, psp, _ := tree.getValue(request.path, getParams)

if handler == nil {
if !request.nilHandler {
Expand All @@ -59,6 +59,10 @@ func checkRequests(t *testing.T, tree *node, requests testRequests) {
if fakeHandlerValue != request.route {
t.Errorf("handle mismatch for route '%s': Wrong handle (%s != %s)", request.path, fakeHandlerValue, request.route)
}

if route != request.route {
t.Errorf("route mismatch for path '%s': Wrong route (%s != %s)", request.path, route, request.route)
}
}

var ps Params
Expand Down Expand Up @@ -427,7 +431,7 @@ func TestTreeTrailingSlashRedirect(t *testing.T) {
"/doc/",
}
for _, route := range tsrRoutes {
handler, _, tsr := tree.getValue(route, nil)
_, handler, _, tsr := tree.getValue(route, nil)
if handler != nil {
t.Fatalf("non-nil handler for TSR route '%s", route)
} else if !tsr {
Expand All @@ -444,7 +448,7 @@ func TestTreeTrailingSlashRedirect(t *testing.T) {
"/api/world/abc",
}
for _, route := range noTsrRoutes {
handler, _, tsr := tree.getValue(route, nil)
_, handler, _, tsr := tree.getValue(route, nil)
if handler != nil {
t.Fatalf("non-nil handler for No-TSR route '%s", route)
} else if tsr {
Expand All @@ -463,7 +467,7 @@ func TestTreeRootTrailingSlashRedirect(t *testing.T) {
t.Fatalf("panic inserting test route: %v", recv)
}

handler, _, tsr := tree.getValue("/", nil)
_, handler, _, tsr := tree.getValue("/", nil)
if handler != nil {
t.Fatalf("non-nil handler")
} else if tsr {
Expand Down

0 comments on commit 2b8357f

Please sign in to comment.