Skip to content

Commit

Permalink
Allow a origin validation function with context
Browse files Browse the repository at this point in the history
  • Loading branch information
dbhoot committed Feb 23, 2024
1 parent fcbd06f commit 82827c2
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 19 deletions.
42 changes: 25 additions & 17 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import (
)

type cors struct {
allowAllOrigins bool
allowCredentials bool
allowOriginFunc func(string) bool
allowOrigins []string
normalHeaders http.Header
preflightHeaders http.Header
wildcardOrigins [][]string
optionsResponseStatusCode int
allowAllOrigins bool
allowCredentials bool
allowOriginFunc func(string) bool
allowOriginWithContextFunc func(*gin.Context, string) bool
allowOrigins []string
normalHeaders http.Header
preflightHeaders http.Header
wildcardOrigins [][]string
optionsResponseStatusCode int
}

var (
Expand Down Expand Up @@ -54,14 +55,15 @@ func newCors(config Config) *cors {
}

return &cors{
allowOriginFunc: config.AllowOriginFunc,
allowAllOrigins: config.AllowAllOrigins,
allowCredentials: config.AllowCredentials,
allowOrigins: normalize(config.AllowOrigins),
normalHeaders: generateNormalHeaders(config),
preflightHeaders: generatePreflightHeaders(config),
wildcardOrigins: config.parseWildcardRules(),
optionsResponseStatusCode: config.OptionsResponseStatusCode,
allowOriginFunc: config.AllowOriginFunc,
allowOriginWithContextFunc: config.AllowOriginWithContextFunc,
allowAllOrigins: config.AllowAllOrigins,
allowCredentials: config.AllowCredentials,
allowOrigins: normalize(config.AllowOrigins),
normalHeaders: generateNormalHeaders(config),
preflightHeaders: generatePreflightHeaders(config),
wildcardOrigins: config.parseWildcardRules(),
optionsResponseStatusCode: config.OptionsResponseStatusCode,
}
}

Expand All @@ -79,7 +81,13 @@ func (cors *cors) applyCors(c *gin.Context) {
return
}

if !cors.validateOrigin(origin) {
if cors.allowOriginWithContextFunc != nil {
if !cors.allowOriginWithContextFunc(c, origin) {
c.AbortWithStatus(http.StatusForbidden)
return
}

} else if !cors.validateOrigin(origin) {
c.AbortWithStatus(http.StatusForbidden)
return
}
Expand Down
10 changes: 8 additions & 2 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ type Config struct {
// set, the content of AllowOrigins is ignored.
AllowOriginFunc func(origin string) bool

// The same as AllowOriginFunc but allows access to the entire request context
AllowOriginWithContextFunc func(c *gin.Context, origin string) bool

// AllowMethods is a list of methods the client is allowed to use with
// cross-domain requests. Default value is simple methods (GET, POST, PUT, PATCH, DELETE, HEAD, and OPTIONS)
AllowMethods []string
Expand Down Expand Up @@ -102,12 +105,15 @@ func (c Config) validateAllowedSchemas(origin string) bool {

// Validate is check configuration of user defined.
func (c Config) Validate() error {
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || c.AllowOriginWithContextFunc != nil || len(c.AllowOrigins) > 0) {
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed")
}
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && c.AllowOriginWithContextFunc == nil && len(c.AllowOrigins) == 0 {
return errors.New("conflict settings: all origins disabled")
}
if c.AllowOriginFunc != nil && c.AllowOriginWithContextFunc != nil {
return errors.New("conflict settings: Both original validation functions are defined")
}
for _, origin := range c.AllowOrigins {
if !strings.Contains(origin, "*") && !c.validateAllowedSchemas(origin) {
return errors.New("bad origin: origins must contain '*' or include " + strings.Join(c.getAllowedSchemas(), ","))
Expand Down
2 changes: 2 additions & 0 deletions cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ func TestGeneratePreflightHeaders_MaxAge(t *testing.T) {
}

func TestValidateOrigin(t *testing.T) {
// review the below for adding a testing context
//https://pkg.go.dev/github.com/gin-gonic/gin#CreateTestContextOnly
cors := newCors(Config{
AllowAllOrigins: true,
})
Expand Down

0 comments on commit 82827c2

Please sign in to comment.