Skip to content

Commit

Permalink
go-aah/aah#126 router updates for websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
jeevatkm committed Apr 24, 2018
1 parent 943bc56 commit 10441cf
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 108 deletions.
20 changes: 10 additions & 10 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,17 @@ var (
// Spec: https://www.w3.org/TR/cors/
// Friendly Read: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
type CORS struct {
AllowOrigins []string
AllowMethods []string
AllowHeaders []string
ExposeHeaders []string
AllowCredentials bool
MaxAge string

maxAgeStr string
allowAllOrigins bool
allowAllMethods bool
allowAllHeaders bool
allowAllOrigins bool
allowAllMethods bool
allowAllHeaders bool

MaxAge string
maxAgeStr string
AllowOrigins []string
AllowMethods []string
AllowHeaders []string
ExposeHeaders []string
}

// AddOrigins method adds the given origin into allow origin list.
Expand Down
2 changes: 2 additions & 0 deletions cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package router

import (
"io/ioutil"
"os"
"path/filepath"
"testing"
Expand All @@ -17,6 +18,7 @@ import (

func TestRouterCORS1(t *testing.T) {
_ = log.SetLevel("TRACE")
log.SetWriter(ioutil.Discard)
wd, _ := os.Getwd()
appCfg, _ := config.ParseString("")
router := New(filepath.Join(wd, "testdata", "routes-cors-1.conf"), appCfg)
Expand Down
21 changes: 11 additions & 10 deletions domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package router
import (
"errors"
"fmt"
"net/http"
"net/url"
"path"
"strings"
Expand All @@ -22,24 +23,24 @@ import (

// Domain is used to hold domain related routes and it's route configuration
type Domain struct {
Name string
Host string
Port string
IsSubDomain bool
MethodNotAllowed bool
RedirectTrailingSlash bool
AutoOptions bool
CORSEnabled bool
Name string
Host string
Port string
DefaultAuth string
CORS *CORS
CORSEnabled bool
trees map[string]*node
routes map[string]*Route
}

// Lookup method finds a route, path parameters, redirect trailing slash
// indicator for given `ahttp.Request` by domain and request URI
// otherwise returns nil and false.
func (d *Domain) Lookup(req *ahttp.Request) (*Route, ahttp.PathParams, bool) {
// Lookup method looks up route if found it returns route, path parameters,
// redirect trailing slash indicator for given `ahttp.Request` by domain
// and request URI otherwise returns nil and false.
func (d *Domain) Lookup(req *http.Request) (*Route, ahttp.PathParams, bool) {
// HTTP method override support
overrideMethod := req.Header.Get(ahttp.HeaderXHTTPMethodOverride)
if !ess.IsStrEmpty(overrideMethod) && req.Method == ahttp.MethodPost {
Expand All @@ -52,7 +53,7 @@ func (d *Domain) Lookup(req *ahttp.Request) (*Route, ahttp.PathParams, bool) {
return nil, nil, false
}

route, pathParams, rts, err := tree.find(req.Path)
route, pathParams, rts, err := tree.find(req.URL.Path)
if route != nil && err == nil {
return route.(*Route), pathParams, rts
} else if rts { // possible Redirect Trailing Slash
Expand Down Expand Up @@ -243,7 +244,7 @@ func (d *Domain) key() string {
return strings.ToLower(d.Host + ":" + d.Port)
}

func (d *Domain) lookupRouteTree(req *ahttp.Request) (*node, bool) {
func (d *Domain) lookupRouteTree(req *http.Request) (*node, bool) {
// get route tree for request method
if tree, found := d.trees[req.Method]; found {
return tree, true
Expand Down
4 changes: 2 additions & 2 deletions radix_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ var (
type nodeType uint8

type node struct {
path string
wildChild bool
nType nodeType
maxParams uint8
indices string
priority uint32
path string
indices string
edges []*node
value interface{}
}
Expand Down
1 change: 0 additions & 1 deletion radix_tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,6 @@ func TestTreeWildcardConflictEx(t *testing.T) {
}

err := tree.add(conflict.route, "conflict.route")
fmt.Println(err)

if !regexp.MustCompile(fmt.Sprintf("'%s' in new path .* conflicts with existing wildcard '%s' in existing prefix '%s'", conflict.segPath, conflict.existSegPath, conflict.existPath)).MatchString(err.Error()) {
t.Fatalf("invalid wildcard conflict error (%v)", err.Error())
Expand Down
89 changes: 53 additions & 36 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ import (
"aahframework.org/log.v0"
)

const wildcardSubdomainPrefix = "*."
const (
wildcardSubdomainPrefix = "*."
methodWebSocket = "WS"
)

var (
// HTTPMethodActionMap is default Controller Action name for corresponding
Expand Down Expand Up @@ -158,14 +161,26 @@ func (r *Router) RegisteredActions() map[string]map[string]uint8 {
methods := map[string]map[string]uint8{}
for _, d := range r.Domains {
for _, route := range d.routes {
if route.IsStatic {
if route.IsStatic || route.Method == methodWebSocket {
continue
}

addRegisteredAction(methods, route)
}
}
return methods
}

// RegisteredWSActions method returns all the WebSocket name and it's actions
// configured in the "routes.conf".
func (r *Router) RegisteredWSActions() map[string]map[string]uint8 {
methods := map[string]map[string]uint8{}
for _, d := range r.Domains {
for _, route := range d.routes {
if route.Method == methodWebSocket {
addRegisteredAction(methods, route)
}
}
}
return methods
}

Expand Down Expand Up @@ -263,8 +278,8 @@ func (r *Router) processRoutesConfig() (err error) {
if !ess.IsStrEmpty(dr.ParentName) {
parentInfo = fmt.Sprintf("(parent: %s)", dr.ParentName)
}
log.Tracef("Route Name: %v %v, Path: %v, Method: %v, Controller: %v, Action: %v, Auth: %v, MaxBodySize: %v\nCORS: [%v]\nValidation Rules:%v\n",
dr.Name, parentInfo, dr.Path, dr.Method, dr.Controller, dr.Action, dr.Auth, dr.MaxBodySize,
log.Tracef("Route Name: %v %v, Path: %v, Method: %v, Target: %v, Action: %v, Auth: %v, MaxBodySize: %v\nCORS: [%v]\nValidation Rules:%v\n",
dr.Name, parentInfo, dr.Path, dr.Method, dr.Target, dr.Action, dr.Auth, dr.MaxBodySize,
dr.CORS, dr.validationRules)
}
}
Expand Down Expand Up @@ -342,33 +357,32 @@ func (r *Router) processRoutes(domain *Domain, domainCfg *config.Config) error {

// Route holds the single route details.
type Route struct {
IsAntiCSRFCheck bool
IsStatic bool
ListDir bool
MaxBodySize int64
Name string
Path string
Method string
Controller string
Target string
Action string
ParentName string
Auth string
MaxBodySize int64
IsAntiCSRFCheck bool
Dir string
File string
CORS *CORS

// static route fields in-addition to above
IsStatic bool
Dir string
File string
ListDir bool

validationRules map[string]string
}

type parentRouteInfo struct {
ParentName string
PrefixPath string
Controller string
Auth string
CORS *CORS
CORSEnabled bool
ParentName string
PrefixPath string
Target string
Auth string
AntiCSRFCheck bool
CORS *CORS
CORSEnabled bool
}

// IsDir method returns true if serving directory otherwise false.
Expand Down Expand Up @@ -434,21 +448,23 @@ func parseRoutesSection(cfg *config.Config, routeInfo *parentRouteInfo) (routes
}
}

// getting 'method', default to GET, if method not found
routeMethod := strings.ToUpper(cfg.StringDefault(routeName+".method", ahttp.MethodGet))

// check child routes exists
notToSkip := true
if cfg.IsExists(routeName + ".routes") {
if !cfg.IsExists(routeName+".action") || !cfg.IsExists(routeName+".controller") {
if !cfg.IsExists(routeName+".action") &&
(!cfg.IsExists(routeName+".controller") || !cfg.IsExists(routeName+".websocket")) {
notToSkip = false
}
}

// getting 'method', default to GET, if method not found
routeMethod := strings.ToUpper(cfg.StringDefault(routeName+".method", ahttp.MethodGet))

// getting 'controller'
routeController := cfg.StringDefault(routeName+".controller", routeInfo.Controller)
if ess.IsStrEmpty(routeController) && notToSkip {
err = fmt.Errorf("'%v.controller' key is missing", routeName)
// getting 'target' info for e.g.: controller, websocket
routeTarget := cfg.StringDefault(routeName+".controller",
cfg.StringDefault(routeName+".websocket", routeInfo.Target))
if ess.IsStrEmpty(routeTarget) && notToSkip {
err = fmt.Errorf("'%v.controller' or '%v.websocket' key is missing", routeName, routeName)
return
}

Expand All @@ -471,7 +487,7 @@ func parseRoutesSection(cfg *config.Config, routeInfo *parentRouteInfo) (routes
}

// getting Anti-CSRF check value, GitHub go-aah/aah#115
routeAntiCSRFCheck := cfg.BoolDefault(routeName+".anti_csrf_check", true)
routeAntiCSRFCheck := cfg.BoolDefault(routeName+".anti_csrf_check", routeInfo.AntiCSRFCheck)

// CORS
var cors *CORS
Expand All @@ -491,7 +507,7 @@ func parseRoutesSection(cfg *config.Config, routeInfo *parentRouteInfo) (routes
Name: routeName,
Path: actualRoutePath,
Method: strings.TrimSpace(m),
Controller: routeController,
Target: routeTarget,
Action: routeAction,
ParentName: routeInfo.ParentName,
Auth: routeAuth,
Expand All @@ -506,12 +522,13 @@ func parseRoutesSection(cfg *config.Config, routeInfo *parentRouteInfo) (routes
// loading child routes
if childRoutes, found := cfg.GetSubConfig(routeName + ".routes"); found {
croutes, er := parseRoutesSection(childRoutes, &parentRouteInfo{
ParentName: routeName,
PrefixPath: routePath,
Controller: routeController,
Auth: routeAuth,
CORS: cors,
CORSEnabled: routeInfo.CORSEnabled,
ParentName: routeName,
PrefixPath: routePath,
Target: routeTarget,
Auth: routeAuth,
AntiCSRFCheck: routeAntiCSRFCheck,
CORS: cors,
CORSEnabled: routeInfo.CORSEnabled,
})
if er != nil {
err = er
Expand Down
Loading

0 comments on commit 10441cf

Please sign in to comment.