Skip to content

Commit

Permalink
Audit: remove any race to read headers (#26155)
Browse files Browse the repository at this point in the history
* adjust code to prevent any data race in reading audited headers

* header tests

* Comment improvement

* make fmt 'fix' for unrelated file
  • Loading branch information
Peter Wilson authored Mar 26, 2024
1 parent 1885f16 commit 54e19c5
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 49 deletions.
5 changes: 2 additions & 3 deletions command/operator_diagnose.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@ import (
"sync"
"time"

"github.com/hashicorp/vault/vault/seal"

"github.com/hashicorp/cli"
"github.com/hashicorp/consul/api"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-kms-wrapping/entropy/v2"
"github.com/hashicorp/go-secure-stdlib/reloadutil"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/go-uuid"
cserver "github.com/hashicorp/vault/command/server"
"github.com/hashicorp/vault/helper/constants"
"github.com/hashicorp/vault/helper/metricsutil"
Expand All @@ -35,6 +33,7 @@ import (
"github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/vault/diagnose"
"github.com/hashicorp/vault/vault/hcp_link"
"github.com/hashicorp/vault/vault/seal"
"github.com/hashicorp/vault/version"
"github.com/posener/complete"
"golang.org/x/term"
Expand Down
4 changes: 2 additions & 2 deletions vault/audit_broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ func testAuditBackend(t *testing.T, path string, config map[string]string) audit
t.Helper()

headersCfg := &AuditedHeadersConfig{
Headers: make(map[string]*auditedHeaderSettings),
view: nil,
headerSettings: make(map[string]*auditedHeaderSettings),
view: nil,
}

view := &logical.InmemStorage{}
Expand Down
61 changes: 47 additions & 14 deletions vault/audited_headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ type auditedHeaderSettings struct {
// AuditedHeadersConfig is used by the Audit Broker to write only approved
// headers to the audit logs. It uses a BarrierView to persist the settings.
type AuditedHeadersConfig struct {
// Headers stores the current headers that should be audited, and their settings.
Headers map[string]*auditedHeaderSettings
// headerSettings stores the current headers that should be audited, and their settings.
headerSettings map[string]*auditedHeaderSettings

// view is the barrier view which should be used to access underlying audit header config data.
view *BarrierView
Expand All @@ -54,11 +54,44 @@ func NewAuditedHeadersConfig(view *BarrierView) (*AuditedHeadersConfig, error) {
// This should be the only place where the AuditedHeadersConfig struct is initialized.
// Store the view so that we can reload headers when we 'invalidate'.
return &AuditedHeadersConfig{
view: view,
Headers: make(map[string]*auditedHeaderSettings),
view: view,
headerSettings: make(map[string]*auditedHeaderSettings),
}, nil
}

// header attempts to retrieve a copy of the settings associated with the specified header.
// The second boolean return parameter indicates whether the header existed in configuration,
// it should be checked as when 'false' the returned settings will have the default values.
func (a *AuditedHeadersConfig) header(name string) (auditedHeaderSettings, bool) {
a.RLock()
defer a.RUnlock()

var s auditedHeaderSettings
v, ok := a.headerSettings[strings.ToLower(name)]

if ok {
s.HMAC = v.HMAC
}

return s, ok
}

// headers returns all existing headers along with a copy of their current settings.
func (a *AuditedHeadersConfig) headers() map[string]auditedHeaderSettings {
a.RLock()
defer a.RUnlock()

// We know how many entries the map should have.
headers := make(map[string]auditedHeaderSettings, len(a.headerSettings))

// Clone the headers
for name, setting := range a.headerSettings {
headers[name] = auditedHeaderSettings{HMAC: setting.HMAC}
}

return headers
}

// add adds or overwrites a header in the config and updates the barrier view
// NOTE: add will acquire a write lock in order to update the underlying headers.
func (a *AuditedHeadersConfig) add(ctx context.Context, header string, hmac bool) error {
Expand All @@ -70,12 +103,12 @@ func (a *AuditedHeadersConfig) add(ctx context.Context, header string, hmac bool
a.Lock()
defer a.Unlock()

if a.Headers == nil {
a.Headers = make(map[string]*auditedHeaderSettings, 1)
if a.headerSettings == nil {
a.headerSettings = make(map[string]*auditedHeaderSettings, 1)
}

a.Headers[strings.ToLower(header)] = &auditedHeaderSettings{hmac}
entry, err := logical.StorageEntryJSON(auditedHeadersEntry, a.Headers)
a.headerSettings[strings.ToLower(header)] = &auditedHeaderSettings{hmac}
entry, err := logical.StorageEntryJSON(auditedHeadersEntry, a.headerSettings)
if err != nil {
return fmt.Errorf("failed to persist audited headers config: %w", err)
}
Expand All @@ -99,12 +132,12 @@ func (a *AuditedHeadersConfig) remove(ctx context.Context, header string) error
defer a.Unlock()

// Nothing to delete
if len(a.Headers) == 0 {
if len(a.headerSettings) == 0 {
return nil
}

delete(a.Headers, strings.ToLower(header))
entry, err := logical.StorageEntryJSON(auditedHeadersEntry, a.Headers)
delete(a.headerSettings, strings.ToLower(header))
entry, err := logical.StorageEntryJSON(auditedHeadersEntry, a.headerSettings)
if err != nil {
return fmt.Errorf("failed to persist audited headers config: %w", err)
}
Expand Down Expand Up @@ -145,7 +178,7 @@ func (a *AuditedHeadersConfig) invalidate(ctx context.Context) error {
lowerHeaders[strings.ToLower(k)] = v
}

a.Headers = lowerHeaders
a.headerSettings = lowerHeaders
return nil
}

Expand All @@ -162,8 +195,8 @@ func (a *AuditedHeadersConfig) ApplyConfig(ctx context.Context, headers map[stri
lowerHeaders[strings.ToLower(k)] = v
}

result = make(map[string][]string, len(a.Headers))
for key, settings := range a.Headers {
result = make(map[string][]string, len(a.headerSettings))
for key, settings := range a.headerSettings {
if val, ok := lowerHeaders[key]; ok {
// copy the header values so we don't overwrite them
hVals := make([]string, len(val))
Expand Down
82 changes: 62 additions & 20 deletions vault/audited_headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ func mockAuditedHeadersConfig(t *testing.T) *AuditedHeadersConfig {
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "foo/")
return &AuditedHeadersConfig{
Headers: make(map[string]*auditedHeaderSettings),
view: view,
headerSettings: make(map[string]*auditedHeaderSettings),
view: view,
}
}

Expand All @@ -66,7 +66,7 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) {
t.Fatalf("Error when adding header to config: %s", err)
}

settings, ok := conf.Headers["x-test-header"]
settings, ok := conf.headerSettings["x-test-header"]
if !ok {
t.Fatal("Expected header to be found in config")
}
Expand Down Expand Up @@ -104,7 +104,7 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) {
t.Fatalf("Error when adding header to config: %s", err)
}

settings, ok = conf.Headers["x-vault-header"]
settings, ok = conf.headerSettings["x-vault-header"]
if !ok {
t.Fatal("Expected header to be found in config")
}
Expand Down Expand Up @@ -142,7 +142,7 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) {
t.Fatalf("Error when adding header to config: %s", err)
}

_, ok := conf.Headers["x-Test-HeAder"]
_, ok := conf.headerSettings["x-Test-HeAder"]
if ok {
t.Fatal("Expected header to not be found in config")
}
Expand Down Expand Up @@ -176,7 +176,7 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) {
t.Fatalf("Error when adding header to config: %s", err)
}

_, ok = conf.Headers["x-vault-header"]
_, ok = conf.headerSettings["x-vault-header"]
if ok {
t.Fatal("Expected header to not be found in config")
}
Expand Down Expand Up @@ -355,11 +355,11 @@ func TestAuditedHeadersConfig_ApplyConfig_HashStringError(t *testing.T) {

func BenchmarkAuditedHeaderConfig_ApplyConfig(b *testing.B) {
conf := &AuditedHeadersConfig{
Headers: make(map[string]*auditedHeaderSettings),
view: nil,
headerSettings: make(map[string]*auditedHeaderSettings),
view: nil,
}

conf.Headers = map[string]*auditedHeaderSettings{
conf.headerSettings = map[string]*auditedHeaderSettings{
"X-Test-Header": {false},
"X-Vault-Header": {true},
}
Expand Down Expand Up @@ -404,7 +404,7 @@ func TestAuditedHeaders_invalidate(t *testing.T) {
view := NewBarrierView(barrier, auditedHeadersSubPath)
ahc, err := NewAuditedHeadersConfig(view)
require.NoError(t, err)
require.Len(t, ahc.Headers, 0)
require.Len(t, ahc.headerSettings, 0)

// Store some data using the view.
fakeHeaders1 := map[string]*auditedHeaderSettings{"x-magic-header": {}}
Expand All @@ -416,8 +416,8 @@ func TestAuditedHeaders_invalidate(t *testing.T) {
// Invalidate and check we now see the header we stored
err = ahc.invalidate(context.Background())
require.NoError(t, err)
require.Len(t, ahc.Headers, 1)
_, ok := ahc.Headers["x-magic-header"]
require.Len(t, ahc.headerSettings, 1)
_, ok := ahc.headerSettings["x-magic-header"]
require.True(t, ok)

// Do it again with more headers and random casing.
Expand All @@ -433,10 +433,10 @@ func TestAuditedHeaders_invalidate(t *testing.T) {
// Invalidate and check we now see the header we stored
err = ahc.invalidate(context.Background())
require.NoError(t, err)
require.Len(t, ahc.Headers, 2)
_, ok = ahc.Headers["x-magic-header"]
require.Len(t, ahc.headerSettings, 2)
_, ok = ahc.headerSettings["x-magic-header"]
require.True(t, ok)
_, ok = ahc.Headers["x-even-more-magic-header"]
_, ok = ahc.headerSettings["x-even-more-magic-header"]
require.True(t, ok)
}

Expand All @@ -447,7 +447,7 @@ func TestAuditedHeaders_invalidate_nil_view(t *testing.T) {
view := NewBarrierView(barrier, auditedHeadersSubPath)
ahc, err := NewAuditedHeadersConfig(view)
require.NoError(t, err)
require.Len(t, ahc.Headers, 0)
require.Len(t, ahc.headerSettings, 0)

// Store some data using the view.
fakeHeaders1 := map[string]*auditedHeaderSettings{"x-magic-header": {}}
Expand All @@ -459,8 +459,8 @@ func TestAuditedHeaders_invalidate_nil_view(t *testing.T) {
// Invalidate and check we now see the header we stored
err = ahc.invalidate(context.Background())
require.NoError(t, err)
require.Len(t, ahc.Headers, 1)
_, ok := ahc.Headers["x-magic-header"]
require.Len(t, ahc.headerSettings, 1)
_, ok := ahc.headerSettings["x-magic-header"]
require.True(t, ok)

// Swap out the view with a mock that returns nil when we try to invalidate.
Expand All @@ -472,7 +472,7 @@ func TestAuditedHeaders_invalidate_nil_view(t *testing.T) {
// Invalidate should clear out the existing headers without error
err = ahc.invalidate(context.Background())
require.NoError(t, err)
require.Len(t, ahc.Headers, 0)
require.Len(t, ahc.headerSettings, 0)
}

// TestAuditedHeaders_invalidate_bad_data ensures that we correctly error if the
Expand All @@ -482,7 +482,7 @@ func TestAuditedHeaders_invalidate_bad_data(t *testing.T) {
view := NewBarrierView(barrier, auditedHeadersSubPath)
ahc, err := NewAuditedHeadersConfig(view)
require.NoError(t, err)
require.Len(t, ahc.Headers, 0)
require.Len(t, ahc.headerSettings, 0)

// Store some bad data using the view.
badBytes, err := json.Marshal("i am bad")
Expand All @@ -495,3 +495,45 @@ func TestAuditedHeaders_invalidate_bad_data(t *testing.T) {
require.Error(t, err)
require.ErrorContains(t, err, "failed to parse config")
}

// TestAuditedHeaders_header checks we can return a copy of settings associated with
// an existing header, and we also know when a header wasn't found.
func TestAuditedHeaders_header(t *testing.T) {
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, auditedHeadersSubPath)
ahc, err := NewAuditedHeadersConfig(view)
require.NoError(t, err)
require.Len(t, ahc.headerSettings, 0)

err = ahc.add(context.Background(), "juan", true)
require.NoError(t, err)
require.Len(t, ahc.headerSettings, 1)

s, ok := ahc.header("juan")
require.True(t, ok)
require.Equal(t, true, s.HMAC)

s, ok = ahc.header("x-magic-token")
require.False(t, ok)
}

// TestAuditedHeaders_headers checks we are able to return a copy of the existing
// configured headers.
func TestAuditedHeaders_headers(t *testing.T) {
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, auditedHeadersSubPath)
ahc, err := NewAuditedHeadersConfig(view)
require.NoError(t, err)
require.Len(t, ahc.headerSettings, 0)

err = ahc.add(context.Background(), "juan", true)
require.NoError(t, err)
err = ahc.add(context.Background(), "john", false)
require.NoError(t, err)
require.Len(t, ahc.headerSettings, 2)

s := ahc.headers()
require.Len(t, s, 2)
require.Equal(t, true, s["juan"].HMAC)
require.Equal(t, false, s["john"].HMAC)
}
17 changes: 7 additions & 10 deletions vault/logical_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -1139,8 +1139,7 @@ func (b *SystemBackend) handleAuditedHeaderUpdate(ctx context.Context, req *logi
return logical.ErrorResponse("missing header name"), nil
}

headerConfig := b.Core.AuditedHeadersConfig()
err := headerConfig.add(ctx, header, hmac)
err := b.Core.AuditedHeadersConfig().add(ctx, header, hmac)
if err != nil {
return nil, err
}
Expand All @@ -1155,8 +1154,7 @@ func (b *SystemBackend) handleAuditedHeaderDelete(ctx context.Context, req *logi
return logical.ErrorResponse("missing header name"), nil
}

headerConfig := b.Core.AuditedHeadersConfig()
err := headerConfig.remove(ctx, header)
err := b.Core.AuditedHeadersConfig().remove(ctx, header)
if err != nil {
return nil, err
}
Expand All @@ -1165,14 +1163,13 @@ func (b *SystemBackend) handleAuditedHeaderDelete(ctx context.Context, req *logi
}

// handleAuditedHeaderRead returns the header configuration for the given header name
func (b *SystemBackend) handleAuditedHeaderRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
func (b *SystemBackend) handleAuditedHeaderRead(_ context.Context, _ *logical.Request, d *framework.FieldData) (*logical.Response, error) {
header := d.Get("header").(string)
if header == "" {
return logical.ErrorResponse("missing header name"), nil
}

headerConfig := b.Core.AuditedHeadersConfig()
settings, ok := headerConfig.Headers[strings.ToLower(header)]
settings, ok := b.Core.AuditedHeadersConfig().header(header)
if !ok {
return logical.ErrorResponse("Could not find header in config"), nil
}
Expand All @@ -1185,12 +1182,12 @@ func (b *SystemBackend) handleAuditedHeaderRead(ctx context.Context, req *logica
}

// handleAuditedHeadersRead returns the whole audited headers config
func (b *SystemBackend) handleAuditedHeadersRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
headerConfig := b.Core.AuditedHeadersConfig()
func (b *SystemBackend) handleAuditedHeadersRead(_ context.Context, _ *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
headerSettings := b.Core.AuditedHeadersConfig().headers()

return &logical.Response{
Data: map[string]interface{}{
"headers": headerConfig.Headers,
"headers": headerSettings,
},
}, nil
}
Expand Down

0 comments on commit 54e19c5

Please sign in to comment.