Skip to content

Commit

Permalink
Put the route path in the request context (#75)
Browse files Browse the repository at this point in the history
* Put full route path in the context

* Update to include routes with no params (#74)

Co-authored-by: Jacob Stuart <[email protected]>

* Combine context route path and params into a single interface

Co-authored-by: stuartclan <[email protected]>
Co-authored-by: Jacob Stuart <[email protected]>
  • Loading branch information
3 people authored Apr 16, 2020
1 parent 8cec559 commit e7c8e18
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 87 deletions.
20 changes: 16 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,15 @@ group.GET("/v1/:id", func(w http.ResponseWriter, r *http.Request, params map[str
// UsingContext returns a version of the router or group with context support.
ctxGroup := group.UsingContext() // sibling to 'group' node in tree
ctxGroup.GET("/v2/:id", func(w http.ResponseWriter, r *http.Request) {
params := httptreemux.ContextParams(r.Context())
ctxData := httptreemux.ContextData(r.Context())
params := ctxData.Params()
id := params["id"]
fmt.Fprintf(w, "GET /api/v2/%s", id)

// Useful for middleware to see which route was hit without dealing with wildcards
routePath := ctxData.Route()

// Prints GET /api/v2/:id id=...
fmt.Fprintf(w, "GET %s id=%s", routePath, id)
})

http.ListenAndServe(":8080", router)
Expand All @@ -58,9 +64,15 @@ router.GET("/:page", func(w http.ResponseWriter, r *http.Request) {

group := router.NewGroup("/api")
group.GET("/v1/:id", func(w http.ResponseWriter, r *http.Request) {
params := httptreemux.ContextParams(r.Context())
ctxData := httptreemux.ContextData(r.Context())
params := ctxData.Params()
id := params["id"]
fmt.Fprintf(w, "GET /api/v1/%s", id)

// Useful for middleware to see which route was hit without dealing with wildcards
routePath := ctxData.Route()

// Prints GET /api/v1/:id id=...
fmt.Fprintf(w, "GET %s id=%s", routePath, id)
})

http.ListenAndServe(":8080", router)
Expand Down
67 changes: 56 additions & 11 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,27 @@ func (cg *ContextGroup) NewGroup(path string) *ContextGroup {
// Handle allows handling HTTP requests via an http.HandlerFunc, as opposed to an httptreemux.HandlerFunc.
// Any parameters from the request URL are stored in a map[string]string in the request's context.
func (cg *ContextGroup) Handle(method, path string, handler http.HandlerFunc) {
fullPath := cg.group.path + path
cg.group.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) {
if params != nil {
r = r.WithContext(AddParamsToContext(r.Context(), params))
routeData := contextData{
route: fullPath,
params: params,
}
r = r.WithContext(AddRouteDataToContext(r.Context(), routeData))
handler(w, r)
})
}

// Handler allows handling HTTP requests via an http.Handler interface, as opposed to an httptreemux.HandlerFunc.
// Any parameters from the request URL are stored in a map[string]string in the request's context.
func (cg *ContextGroup) Handler(method, path string, handler http.Handler) {
fullPath := cg.group.path + path
cg.group.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) {
if params != nil {
r = r.WithContext(AddParamsToContext(r.Context(), params))
routeData := contextData{
route: fullPath,
params: params,
}
r = r.WithContext(AddRouteDataToContext(r.Context(), routeData))
handler.ServeHTTP(w, r)
})
}
Expand Down Expand Up @@ -112,22 +118,61 @@ func (cg *ContextGroup) OPTIONS(path string, handler http.HandlerFunc) {
cg.Handle("OPTIONS", path, handler)
}

type contextData struct {
route string
params map[string]string
}

func (cd contextData) Route() string {
return cd.route
}

func (cd contextData) Params() map[string]string {
if cd.params != nil {
return cd.params
}
return map[string]string{}
}

// ContextData is the information associated with
type ContextRouteData interface {
Route() string
Params() map[string]string
}

// ContextParams returns the params map associated with the given context if one exists. Otherwise, an empty map is returned.
func ContextParams(ctx context.Context) map[string]string {
if p, ok := ctx.Value(paramsContextKey).(map[string]string); ok {
return p
if p, ok := ctx.Value(routeContextKey).(ContextRouteData); ok {
return p.Params()
}
return map[string]string{}
}

// ContextData returns the full route path associated with the given context, without wildcard expansion.
func ContextData(ctx context.Context) ContextRouteData {
if p, ok := ctx.Value(routeContextKey).(ContextRouteData); ok {
return p
}
return nil
}

func AddRouteDataToContext(ctx context.Context, data ContextRouteData) context.Context {
return context.WithValue(ctx, routeContextKey, data)
}

// AddParamsToContext inserts a parameters map into a context using
// the package's internal context key. Clients of this package should
// really only use this for unit tests.
// the package's internal context key. This function is deprecated.
// Use AddRouteDataToContext instead.
func AddParamsToContext(ctx context.Context, params map[string]string) context.Context {
return context.WithValue(ctx, paramsContextKey, params)
data := contextData{
route: "",
params: params,
}
return AddRouteDataToContext(ctx, data)
}

type contextKey int

// paramsContextKey is used to retrieve a path's params map from a request's context.
const paramsContextKey contextKey = 0
// paramsContextKey and routeContextKey are used to retrieve a path's params map and route
// from a request's context.
const routeContextKey contextKey = 0
195 changes: 123 additions & 72 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package httptreemux

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
Expand All @@ -23,105 +24,155 @@ type IContextGroup interface {
}

func TestContextParams(t *testing.T) {
m := map[string]string{"id": "123"}
ctx := context.WithValue(context.Background(), paramsContextKey, m)
m := contextData{
params: map[string]string{"id": "123"},
route: "",
}

ctx := context.WithValue(context.Background(), routeContextKey, m)

params := ContextParams(ctx)
if params == nil {
t.Errorf("expected '%#v', but got '%#v'", m, params)
}

if v := params["id"]; v != "123" {
t.Errorf("expected '%s', but got '%#v'", m["id"], params["id"])
t.Errorf("expected '%s', but got '%#v'", m.params["id"], params["id"])
}
}

func TestContextData(t *testing.T) {
p := contextData{
route: "route/path",
params: map[string]string{"id": "123"},
}

ctx := context.WithValue(context.Background(), routeContextKey, p)

ctxData := ContextData(ctx)
pathValue := ctxData.Route()
if pathValue != p.route {
t.Errorf("expected '%s', but got '%s'", p, pathValue)
}

params := ctxData.Params()
if v := params["id"]; v != "123" {
t.Errorf("expected '%s', but got '%#v'", p.params["id"], params["id"])
}
}

func TestContextDataWithEmptyParams(t *testing.T) {
p := contextData{
route: "route/path",
params: nil,
}

ctx := context.WithValue(context.Background(), routeContextKey, p)
params := ContextData(ctx).Params()
if params == nil {
t.Errorf("ContextData.Params should never return nil")
}
}

func TestContextGroupMethods(t *testing.T) {
for _, scenario := range scenarios {
t.Log(scenario.description)
testContextGroupMethods(t, scenario.RequestCreator, true, false)
testContextGroupMethods(t, scenario.RequestCreator, false, false)
testContextGroupMethods(t, scenario.RequestCreator, true, true)
testContextGroupMethods(t, scenario.RequestCreator, false, true)
t.Run(scenario.description, func(t *testing.T) {
testContextGroupMethods(t, scenario.RequestCreator, true, false)
testContextGroupMethods(t, scenario.RequestCreator, false, false)
testContextGroupMethods(t, scenario.RequestCreator, true, true)
testContextGroupMethods(t, scenario.RequestCreator, false, true)
})
}
}

func testContextGroupMethods(t *testing.T, reqGen RequestCreator, headCanUseGet bool, useContextRouter bool) {
t.Logf("Running test: headCanUseGet %v, useContextRouter %v", headCanUseGet, useContextRouter)
t.Run(fmt.Sprintf("headCanUseGet %v, useContextRouter %v", headCanUseGet, useContextRouter), func(t *testing.T) {
var result string
makeHandler := func(method, expectedRoutePath string, hasParam bool) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
result = method

// Test Legacy Accessor
var v string
v, ok := ContextParams(r.Context())["param"]
if hasParam && !ok {
t.Error("missing key 'param' in context from ContextParams")
}

ctxData := ContextData(r.Context())
v, ok = ctxData.Params()["param"]
if hasParam && !ok {
t.Error("missing key 'param' in context from ContextData")
}

routePath := ctxData.Route()
if routePath != expectedRoutePath {
t.Errorf("Expected context to have route path '%s', saw %s", expectedRoutePath, routePath)
}

if headCanUseGet && (method == "GET" || v == "HEAD") {
return
}
if hasParam && v != method {
t.Errorf("invalid key 'param' in context; expected '%s' but got '%s'", method, v)
}
}
}

var result string
makeHandler := func(method string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
result = method
var router http.Handler
var rootGroup IContextGroup

v, ok := ContextParams(r.Context())["param"]
if !ok {
t.Error("missing key 'param' in context")
}
if useContextRouter {
root := NewContextMux()
root.HeadCanUseGet = headCanUseGet
t.Log(root.TreeMux.HeadCanUseGet)
router = root
rootGroup = root
} else {
root := New()
root.HeadCanUseGet = headCanUseGet
router = root
rootGroup = root.UsingContext()
}

if headCanUseGet && (method == "GET" || v == "HEAD") {
return
cg := rootGroup.NewGroup("/base").NewGroup("/user")
cg.GET("/:param", makeHandler("GET", cg.group.path+"/:param", true))
cg.POST("/:param", makeHandler("POST", cg.group.path+"/:param", true))
cg.PATCH("/PATCH", makeHandler("PATCH", cg.group.path+"/PATCH", false))
cg.PUT("/:param", makeHandler("PUT", cg.group.path+"/:param", true))
cg.Handler("DELETE", "/:param", http.HandlerFunc(makeHandler("DELETE", cg.group.path+"/:param", true)))

testMethod := func(method, expect string) {
result = ""
w := httptest.NewRecorder()
r, _ := reqGen(method, "/base/user/"+method, nil)
router.ServeHTTP(w, r)
if expect == "" && w.Code != http.StatusMethodNotAllowed {
t.Errorf("Method %s not expected to match but saw code %d", method, w.Code)
}

if v != method {
t.Errorf("invalid key 'param' in context; expected '%s' but got '%s'", method, v)
if result != expect {
t.Errorf("Method %s got result %s", method, result)
}
}
}

var router http.Handler
var rootGroup IContextGroup

if useContextRouter {
root := NewContextMux()
root.HeadCanUseGet = headCanUseGet
t.Log(root.TreeMux.HeadCanUseGet)
router = root
rootGroup = root
} else {
root := New()
root.HeadCanUseGet = headCanUseGet
router = root
rootGroup = root.UsingContext()
}

cg := rootGroup.NewGroup("/base").NewGroup("/user")
cg.GET("/:param", makeHandler("GET"))
cg.POST("/:param", makeHandler("POST"))
cg.PATCH("/:param", makeHandler("PATCH"))
cg.PUT("/:param", makeHandler("PUT"))
cg.DELETE("/:param", makeHandler("DELETE"))
testMethod("GET", "GET")
testMethod("POST", "POST")
testMethod("PATCH", "PATCH")
testMethod("PUT", "PUT")
testMethod("DELETE", "DELETE")

testMethod := func(method, expect string) {
result = ""
w := httptest.NewRecorder()
r, _ := reqGen(method, "/base/user/"+method, nil)
router.ServeHTTP(w, r)
if expect == "" && w.Code != http.StatusMethodNotAllowed {
t.Errorf("Method %s not expected to match but saw code %d", method, w.Code)
}

if result != expect {
t.Errorf("Method %s got result %s", method, result)
if headCanUseGet {
t.Log("Test implicit HEAD with HeadCanUseGet = true")
testMethod("HEAD", "GET")
} else {
t.Log("Test implicit HEAD with HeadCanUseGet = false")
testMethod("HEAD", "")
}
}

testMethod("GET", "GET")
testMethod("POST", "POST")
testMethod("PATCH", "PATCH")
testMethod("PUT", "PUT")
testMethod("DELETE", "DELETE")

if headCanUseGet {
t.Log("Test implicit HEAD with HeadCanUseGet = true")
testMethod("HEAD", "GET")
} else {
t.Log("Test implicit HEAD with HeadCanUseGet = false")
testMethod("HEAD", "")
}

cg.HEAD("/:param", makeHandler("HEAD"))
testMethod("HEAD", "HEAD")
cg.HEAD("/:param", makeHandler("HEAD", cg.group.path+"/:param", true))
testMethod("HEAD", "HEAD")
})
}

func TestNewContextGroup(t *testing.T) {
Expand Down

0 comments on commit e7c8e18

Please sign in to comment.