Skip to content

Commit

Permalink
Fix up CORS.
Browse files Browse the repository at this point in the history
Ref #2021
  • Loading branch information
jefferai committed Jun 17, 2017
1 parent 362227c commit 27e584c
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 168 deletions.
3 changes: 2 additions & 1 deletion command/rekey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ func TestRekey_init_pgp(t *testing.T) {
MaxLeaseTTLVal: time.Hour * 24 * 32,
},
}
sysBackend, err := vault.NewSystemBackend(core, bc)
sysBE := vault.NewSystemBackend(core)
sysBackend, err := sysBE.Backend.Setup(bc)
if err != nil {
t.Fatal(err)
}
Expand Down
10 changes: 5 additions & 5 deletions vault/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,11 +450,13 @@ func NewCore(conf *CoreConfig) (*Core, error) {
clusterName: conf.ClusterName,
clusterListenerShutdownCh: make(chan struct{}),
clusterListenerShutdownSuccessCh: make(chan struct{}),
corsConfig: &CORSConfig{},
clusterPeerClusterAddrsCache: cache.New(3*heartbeatInterval, time.Second),
enableMlock: !conf.DisableMlock,
}

// Load CORS config and provide core
c.corsConfig = &CORSConfig{core: c}

// Wrap the physical backend in a cache layer if enabled and not already wrapped
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
Expand Down Expand Up @@ -513,7 +515,8 @@ func NewCore(conf *CoreConfig) (*Core, error) {
}
logicalBackends["cubbyhole"] = CubbyholeBackendFactory
logicalBackends["system"] = func(config *logical.BackendConfig) (logical.Backend, error) {
return NewSystemBackend(c, config)
b := NewSystemBackend(c)
return b.Backend.Setup(config)
}
c.logicalBackends = logicalBackends

Expand Down Expand Up @@ -1368,9 +1371,6 @@ func (c *Core) preSeal() error {
if err := c.teardownPolicyStore(); err != nil {
result = multierror.Append(result, errwrap.Wrapf("error tearing down policy store: {{err}}", err))
}
if err := c.saveCORSConfig(); err != nil {
result = multierror.Append(result, errwrap.Wrapf("error tearing down CORS config: {{err}}", err))
}
if err := c.stopRollback(); err != nil {
result = multierror.Append(result, errwrap.Wrapf("error stopping rollback: {{err}}", err))
}
Expand Down
59 changes: 39 additions & 20 deletions vault/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,36 @@ import (
"errors"
"fmt"
"sync"
"sync/atomic"

"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
)

var errCORSNotConfigured = errors.New("CORS is not configured")
const (
CORSDisabled uint32 = iota
CORSEnabled
)

// CORSConfig stores the state of the CORS configuration.
type CORSConfig struct {
sync.RWMutex
Enabled bool `json:"enabled"`
AllowedOrigins []string `json:"allowed_origins"`
sync.RWMutex `json:"-"`
core *Core
Enabled uint32 `json:"enabled"`
AllowedOrigins []string `json:"allowed_origins,omitempty"`
}

func (c *Core) saveCORSConfig() error {
view := c.systemBarrierView.SubView("config/")

entry, err := logical.StorageEntryJSON("cors", c.corsConfig)
localConfig := &CORSConfig{
Enabled: atomic.LoadUint32(&c.corsConfig.Enabled),
}
c.corsConfig.RLock()
localConfig.AllowedOrigins = c.corsConfig.AllowedOrigins
c.corsConfig.RUnlock()

entry, err := logical.StorageEntryJSON("cors", localConfig)
if err != nil {
return fmt.Errorf("failed to create CORS config entry: %v", err)
}
Expand All @@ -33,6 +45,7 @@ func (c *Core) saveCORSConfig() error {
return nil
}

// This should only be called with the core state lock held for writing
func (c *Core) loadCORSConfig() error {
view := c.systemBarrierView.SubView("config/")

Expand All @@ -45,10 +58,14 @@ func (c *Core) loadCORSConfig() error {
return nil
}

err = out.DecodeJSON(c.corsConfig)
newConfig := new(CORSConfig)
err = out.DecodeJSON(newConfig)
if err != nil {
return err
}
newConfig.core = c

c.corsConfig = newConfig

return nil
}
Expand All @@ -65,38 +82,40 @@ func (c *CORSConfig) Enable(urls []string) error {
}

c.Lock()
defer c.Unlock()

c.AllowedOrigins = urls
c.Enabled = true
c.Unlock()

return nil
atomic.StoreUint32(&c.Enabled, CORSEnabled)

return c.core.saveCORSConfig()
}

// IsEnabled returns the value of CORSConfig.isEnabled
func (c *CORSConfig) IsEnabled() bool {
c.RLock()
defer c.RUnlock()

return c.Enabled
return atomic.LoadUint32(&c.Enabled) == CORSEnabled
}

// Disable sets CORS to disabled and clears the allowed origins
func (c *CORSConfig) Disable() {
func (c *CORSConfig) Disable() error {
atomic.StoreUint32(&c.Enabled, CORSDisabled)
c.Lock()
defer c.Unlock()

c.Enabled = false
c.AllowedOrigins = []string{}
c.AllowedOrigins = []string(nil)
c.Unlock()
return c.core.saveCORSConfig()
}

// IsValidOrigin determines if the origin of the request is allowed to make
// cross-origin requests based on the CORSConfig.
func (c *CORSConfig) IsValidOrigin(origin string) bool {
// If we aren't enabling CORS then all origins are valid
if !c.IsEnabled() {
return true
}

c.RLock()
defer c.RUnlock()

if c.AllowedOrigins == nil {
if len(c.AllowedOrigins) == 0 {
return false
}

Expand Down
50 changes: 25 additions & 25 deletions vault/logical_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ var (
}
)

func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backend, error) {
func NewSystemBackend(core *Core) *SystemBackend {
b := &SystemBackend{
Core: core,
}
Expand All @@ -62,7 +62,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen
"replication/primary/secondary-token",
"replication/reindex",
"rotate",
"config/*",
"config/cors",
"config/auditing/*",
"plugins/catalog/*",
"revoke-prefix/*",
Expand Down Expand Up @@ -110,7 +110,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen
},
"allowed_origins": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: "A comma-separated list of origins that may make cross-origin requests.",
Description: "A comma-separated string or array of strings indicating origins that may make cross-origin requests.",
},
},

Expand Down Expand Up @@ -823,50 +823,50 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen

b.Backend.Invalidate = b.invalidate

return b.Backend.Setup(config)
return b
}

// SystemBackend implements logical.Backend and is used to interact with
// the core of the system. This backend is hardcoded to exist at the "sys"
// prefix. Conceptually it is similar to procfs on Linux.
type SystemBackend struct {
Core *Core
Backend *framework.Backend
*framework.Backend
Core *Core
}

// handleCORSRead returns the current CORS configuration
func (b *SystemBackend) handleCORSRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
corsConf := b.Core.corsConfig
if corsConf == nil {
return nil, errCORSNotConfigured
}

return &logical.Response{
enabled := corsConf.IsEnabled()

resp := &logical.Response{
Data: map[string]interface{}{
"enabled": corsConf.Enabled,
"allowed_origins": strings.Join(corsConf.AllowedOrigins, ","),
"enabled": enabled,
},
}, nil
}

if enabled {
corsConf.RLock()
resp.Data["allowed_origins"] = corsConf.AllowedOrigins
corsConf.RUnlock()
}

return resp, nil
}

// handleCORSUpdate sets the list of origins that are allowed
// to make cross-origin requests and sets the CORS enabled flag to true
// handleCORSUpdate sets the list of origins that are allowed to make
// cross-origin requests and sets the CORS enabled flag to true
func (b *SystemBackend) handleCORSUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
origins := d.Get("allowed_origins").([]string)

err := b.Core.corsConfig.Enable(origins)
if err != nil {
return nil, err
}

return nil, nil
return nil, b.Core.corsConfig.Enable(origins)
}

// handleCORSDelete clears the allowed origins and sets the CORS enabled flag to false
// handleCORSDelete clears the allowed origins and sets the CORS enabled flag
// to false
func (b *SystemBackend) handleCORSDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
b.Core.CORSConfig().Disable()

return nil, nil
return nil, b.Core.corsConfig.Disable()
}

func (b *SystemBackend) handleTidyLeases(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
Expand Down
22 changes: 14 additions & 8 deletions vault/logical_system_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ func TestSystemBackend_RootPaths(t *testing.T) {

func TestSystemConfigCORS(t *testing.T) {
b := testSystemBackend(t)
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "")
b.(*SystemBackend).Core.systemBarrierView = view

req := logical.TestRequest(t, logical.UpdateOperation, "config/cors")
req.Data["allowed_origins"] = "http://www.example.com"
Expand All @@ -60,7 +63,7 @@ func TestSystemConfigCORS(t *testing.T) {
expected := &logical.Response{
Data: map[string]interface{}{
"enabled": true,
"allowed_origins": "http://www.example.com",
"allowed_origins": []string{"http://www.example.com"},
},
}

Expand All @@ -71,7 +74,7 @@ func TestSystemConfigCORS(t *testing.T) {
}

if !reflect.DeepEqual(actual, expected) {
t.Fatalf("UPDATE FAILED -- bad: %#v", actual)
t.Fatalf("bad: %#v", actual)
}

req = logical.TestRequest(t, logical.DeleteOperation, "config/cors")
Expand All @@ -88,8 +91,7 @@ func TestSystemConfigCORS(t *testing.T) {

expected = &logical.Response{
Data: map[string]interface{}{
"enabled": false,
"allowed_origins": "",
"enabled": false,
},
}

Expand Down Expand Up @@ -980,7 +982,8 @@ func TestSystemBackend_revokePrefixAuth(t *testing.T) {
MaxLeaseTTLVal: time.Hour * 24 * 32,
},
}
b, err := NewSystemBackend(core, bc)
be := NewSystemBackend(core)
b, err := be.Backend.Setup(bc)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1043,7 +1046,8 @@ func TestSystemBackend_revokePrefixAuth_origUrl(t *testing.T) {
MaxLeaseTTLVal: time.Hour * 24 * 32,
},
}
b, err := NewSystemBackend(core, bc)
be := NewSystemBackend(core)
b, err := be.Backend.Setup(bc)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1578,7 +1582,8 @@ func testSystemBackend(t *testing.T) logical.Backend {
},
}

b, err := NewSystemBackend(c, bc)
b := NewSystemBackend(c)
_, err := b.Backend.Setup(bc)
if err != nil {
t.Fatal(err)
}
Expand All @@ -1596,7 +1601,8 @@ func testCoreSystemBackend(t *testing.T) (*Core, logical.Backend, string) {
},
}

b, err := NewSystemBackend(c, bc)
b := NewSystemBackend(c)
_, err := b.Backend.Setup(bc)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading

0 comments on commit 27e584c

Please sign in to comment.