Skip to content

Commit

Permalink
Cors headers (#2021)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aaron Salvo authored and jefferai committed Jun 17, 2017
1 parent bd7cbe8 commit 362227c
Show file tree
Hide file tree
Showing 13 changed files with 580 additions and 6 deletions.
56 changes: 56 additions & 0 deletions api/sys_config_cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package api

func (c *Sys) CORSStatus() (*CORSResponse, error) {
r := c.c.NewRequest("GET", "/v1/sys/config/cors")
resp, err := c.c.RawRequest(r)
if err != nil {
return nil, err
}
defer resp.Body.Close()

var result CORSResponse
err = resp.DecodeJSON(&result)
return &result, err
}

func (c *Sys) ConfigureCORS(req *CORSRequest) (*CORSResponse, error) {
r := c.c.NewRequest("PUT", "/v1/sys/config/cors")
if err := r.SetJSONBody(req); err != nil {
return nil, err
}

resp, err := c.c.RawRequest(r)
if err != nil {
return nil, err
}
defer resp.Body.Close()

var result CORSResponse
err = resp.DecodeJSON(&result)
return &result, err
}

func (c *Sys) DisableCORS() (*CORSResponse, error) {
r := c.c.NewRequest("DELETE", "/v1/sys/config/cors")

resp, err := c.c.RawRequest(r)
if err != nil {
return nil, err
}
defer resp.Body.Close()

var result CORSResponse
err = resp.DecodeJSON(&result)
return &result, err

}

type CORSRequest struct {
AllowedOrigins string `json:"allowed_origins"`
Enabled bool `json:"enabled"`
}

type CORSResponse struct {
AllowedOrigins string `json:"allowed_origins"`
Enabled bool `json:"enabled"`
}
1 change: 0 additions & 1 deletion cli/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory {
Meta: *metaPtr,
}, nil
},

"server": func() (cli.Command, error) {
return &command.ServerCommand{
Meta: *metaPtr,
Expand Down
68 changes: 68 additions & 0 deletions http/cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package http

import (
"net/http"
"strings"

"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/vault"
)

var preflightHeaders = map[string]string{
"Access-Control-Allow-Headers": "*",
"Access-Control-Max-Age": "300",
}

var allowedMethods = []string{
http.MethodDelete,
http.MethodGet,
http.MethodOptions,
http.MethodPost,
http.MethodPut,
"LIST", // LIST is not an official HTTP method, but Vault supports it.
}

func wrapCORSHandler(h http.Handler, core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
corsConf := core.CORSConfig()

origin := req.Header.Get("Origin")
requestMethod := req.Header.Get("Access-Control-Request-Method")

// If CORS is not enabled or if no Origin header is present (i.e. the request
// is from the Vault CLI. A browser will always send an Origin header), then
// just return a 204.
if !corsConf.IsEnabled() || origin == "" {
h.ServeHTTP(w, req)
return
}

// Return a 403 if the origin is not
// allowed to make cross-origin requests.
if !corsConf.IsValidOrigin(origin) {
w.WriteHeader(http.StatusForbidden)
return
}

if req.Method == http.MethodOptions && !strutil.StrListContains(allowedMethods, requestMethod) {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}

w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Vary", "Origin")

// apply headers for preflight requests
if req.Method == http.MethodOptions {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ","))

for k, v := range preflightHeaders {
w.Header().Set(k, v)
}
return
}

h.ServeHTTP(w, req)
return
})
}
3 changes: 2 additions & 1 deletion http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,11 @@ func Handler(core *vault.Core) http.Handler {

// Wrap the handler in another handler to trigger all help paths.
helpWrappedHandler := wrapHelpHandler(mux, core)
corsWrappedHandler := wrapCORSHandler(helpWrappedHandler, core)

// Wrap the help wrapped handler with another layer with a generic
// handler
genericWrappedHandler := wrapGenericHandler(helpWrappedHandler)
genericWrappedHandler := wrapGenericHandler(corsWrappedHandler)

return genericWrappedHandler
}
Expand Down
81 changes: 81 additions & 0 deletions http/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,87 @@ import (
"github.com/hashicorp/vault/vault"
)

func TestHandler_cors(t *testing.T) {
core, _, _ := vault.TestCoreUnsealed(t)
ln, addr := TestServer(t, core)
defer ln.Close()

// Enable CORS and allow from any origin for testing.
corsConfig := core.CORSConfig()
err := corsConfig.Enable([]string{addr})
if err != nil {
t.Fatalf("Error enabling CORS: %s", err)
}

req, err := http.NewRequest(http.MethodOptions, addr+"/v1/sys/seal-status", nil)
if err != nil {
t.Fatalf("err: %s", err)
}
req.Header.Set("Origin", "BAD ORIGIN")

// Requests from unacceptable origins will be rejected with a 403.
client := cleanhttp.DefaultClient()
resp, err := client.Do(req)
if err != nil {
t.Fatalf("err: %s", err)
}

if resp.StatusCode != http.StatusForbidden {
t.Fatalf("Bad status:\nexpected: 403 Forbidden\nactual: %s", resp.Status)
}

//
// Test preflight requests
//

// Set a valid origin
req.Header.Set("Origin", addr)

// Server should NOT accept arbitrary methods.
req.Header.Set("Access-Control-Request-Method", "FOO")

client = cleanhttp.DefaultClient()
resp, err = client.Do(req)
if err != nil {
t.Fatalf("err: %s", err)
}

// Fail if an arbitrary method is accepted.
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Fatalf("Bad status:\nexpected: 405 Method Not Allowed\nactual: %s", resp.Status)
}

// Server SHOULD accept acceptable methods.
req.Header.Set("Access-Control-Request-Method", http.MethodPost)

client = cleanhttp.DefaultClient()
resp, err = client.Do(req)
if err != nil {
t.Fatalf("err: %s", err)
}

//
// Test that the CORS headers are applied correctly.
//
expHeaders := map[string]string{
"Access-Control-Allow-Origin": addr,
"Access-Control-Allow-Headers": "*",
"Access-Control-Max-Age": "300",
"Vary": "Origin",
}

for expHeader, expected := range expHeaders {
actual := resp.Header.Get(expHeader)
if actual == "" {
t.Fatalf("bad:\nHeader: %#v was not on response.", expHeader)
}

if actual != expected {
t.Fatalf("bad:\nExpected: %#v\nActual: %#v\n", expected, actual)
}
}
}

func TestHandler_CacheControlNoStore(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := TestServer(t, core)
Expand Down
6 changes: 6 additions & 0 deletions http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -55,6 +56,11 @@ func testHttpData(t *testing.T, method string, token string, addr string, body i
t.Fatalf("err: %s", err)
}

// Get the address of the local listener in order to attach it to an Origin header.
// This will allow for the testing of requests that require CORS, without using a browser.
hostURLRegexp, _ := regexp.Compile("http[s]?://.+:[0-9]+")
req.Header.Set("Origin", hostURLRegexp.FindString(addr))

req.Header.Set("Content-Type", "application/json")

if len(token) != 0 {
Expand Down
1 change: 1 addition & 0 deletions http/logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
op = logical.UpdateOperation
case "LIST":
op = logical.ListOperation
case "OPTIONS":
default:
return nil, http.StatusMethodNotAllowed, nil
}
Expand Down
15 changes: 15 additions & 0 deletions vault/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,9 @@ type Core struct {
// The grpc forwarding client
rpcForwardingClient *forwardingClient

// CORS Information
corsConfig *CORSConfig

// replicationState keeps the current replication state cached for quick
// lookup
replicationState consts.ReplicationState
Expand Down Expand Up @@ -447,6 +450,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
clusterName: conf.ClusterName,
clusterListenerShutdownCh: make(chan struct{}),
clusterListenerShutdownSuccessCh: make(chan struct{}),
corsConfig: &CORSConfig{},
clusterPeerClusterAddrsCache: cache.New(3*heartbeatInterval, time.Second),
enableMlock: !conf.DisableMlock,
}
Expand Down Expand Up @@ -555,6 +559,11 @@ func (c *Core) Shutdown() error {
return c.sealInternal()
}

// CORSConfig returns the current CORS configuration
func (c *Core) CORSConfig() *CORSConfig {
return c.corsConfig
}

// LookupToken returns the properties of the token from the token store. This
// is particularly useful to fetch the accessor of the client token and get it
// populated in the logical request along with the client token. The accessor
Expand Down Expand Up @@ -1291,6 +1300,9 @@ func (c *Core) postUnseal() (retErr error) {
if err := c.setupPolicyStore(); err != nil {
return err
}
if err := c.loadCORSConfig(); err != nil {
return err
}
if err := c.loadCredentials(); err != nil {
return err
}
Expand Down Expand Up @@ -1356,6 +1368,9 @@ func (c *Core) preSeal() error {
if err := c.teardownPolicyStore(); err != nil {
result = multierror.Append(result, errwrap.Wrapf("error tearing down policy store: {{err}}", err))
}
if err := c.saveCORSConfig(); err != nil {
result = multierror.Append(result, errwrap.Wrapf("error tearing down CORS config: {{err}}", err))
}
if err := c.stopRollback(); err != nil {
result = multierror.Append(result, errwrap.Wrapf("error stopping rollback: {{err}}", err))
}
Expand Down
Loading

0 comments on commit 362227c

Please sign in to comment.