Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify approle tidy to validate dangling accessors #4981

Merged
merged 1 commit into from
Jul 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions builtin/credential/approle/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package approle
import (
"context"
"sync"
"time"

"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/locksutil"
Expand Down Expand Up @@ -56,6 +57,8 @@ type backend struct {
// secretIDListingLock is a dedicated lock for listing SecretIDAccessors
// for all the SecretIDs issued against an approle
secretIDListingLock sync.RWMutex

testTidyDelay time.Duration
}

func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
Expand Down
83 changes: 74 additions & 9 deletions builtin/credential/approle/path_tidy_user_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,29 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
go func() {
defer atomic.StoreUint32(b.tidySecretIDCASGuard, 0)

logger := b.Logger().Named("tidy")

checkCount := 0

defer func() {
if b.testTidyDelay > 0 {
logger.Trace("done checking entries", "num_entries", checkCount)
}
}()

// Don't cancel when the original client request goes away
ctx = context.Background()

logger := b.Logger().Named("tidy")

tidyFunc := func(secretIDPrefixToUse, accessorIDPrefixToUse string) error {
logger.Trace("listing role HMACs", "prefix", secretIDPrefixToUse)

roleNameHMACs, err := s.List(ctx, secretIDPrefixToUse)
if err != nil {
return err
}

logger.Trace("listing accessors", "prefix", accessorIDPrefixToUse)

// List all the accessors and add them all to a map
accessorHashes, err := s.List(ctx, accessorIDPrefixToUse)
if err != nil {
Expand All @@ -59,7 +71,10 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
accessorMap[accessorHash] = true
}

time.Sleep(b.testTidyDelay)

secretIDCleanupFunc := func(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse string) error {
checkCount++
lock := b.secretIDLock(secretIDHMAC)
lock.Lock()
defer lock.Unlock()
Expand Down Expand Up @@ -91,6 +106,7 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
return errwrap.Wrapf("failed to read secret ID accessor entry: {{err}}", err)
}
if accessorEntry == nil {
logger.Trace("found nil accessor")
if err := s.Delete(ctx, entryIndex); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting secret ID %q from storage: {{err}}", secretIDHMAC), err)
}
Expand All @@ -99,6 +115,7 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi

// ExpirationTime not being set indicates non-expiring SecretIDs
if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) {
logger.Trace("found expired secret ID")
// Clean up the accessor of the secret ID first
err = b.deleteSecretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse)
if err != nil {
Expand Down Expand Up @@ -126,6 +143,7 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
}

for _, roleNameHMAC := range roleNameHMACs {
logger.Trace("listing secret ID HMACs", "role_hmac", roleNameHMAC)
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC))
if err != nil {
return err
Expand All @@ -140,13 +158,60 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi

// Accessor indexes were not getting cleaned up until 0.9.3. This is a fix
// to clean up the dangling accessor entries.
for accessorHash, _ := range accessorMap {
// Ideally, locking should be performed here. But for that, accessors
// are required in plaintext, which are not available. Hence performing
// a racy cleanup.
err = s.Delete(ctx, secretIDAccessorPrefix+accessorHash)
if err != nil {
return err
if len(accessorMap) > 0 {
for _, lock := range b.secretIDLocks {
lock.Lock()
defer lock.Unlock()
}
for accessorHash, _ := range accessorMap {
logger.Trace("found dangling accessor, verifying")
// Ideally, locking on accessors should be performed here too
// but for that, accessors are required in plaintext, which are
// not available. The code above helps but it may still be
// racy.
// ...
// Look up the secret again now that we have all the locks. The
// lock is held when writing accessor/secret so if we have the
// lock we know we're not in a
// wrote-accessor-but-not-yet-secret case, which can be racy.
var entry secretIDAccessorStorageEntry
entryIndex := accessorIDPrefixToUse + accessorHash
se, err := s.Get(ctx, entryIndex)
if err != nil {
return err
}
if se != nil {
err = se.DecodeJSON(&entry)
if err != nil {
return err
}

// The storage entry doesn't store the role ID, so we have
// to go about this the long way; fortunately we shouldn't
// actually hit this very often
var found bool
searchloop:
for _, roleNameHMAC := range roleNameHMACs {
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC))
if err != nil {
return err
}
for _, v := range secretIDHMACs {
if v == entry.SecretIDHMAC {
found = true
logger.Trace("accessor verified, not removing")
break searchloop
}
}
}
if !found {
logger.Trace("could not verify dangling accessor, removing")
err = s.Delete(ctx, entryIndex)
if err != nil {
return err
}
}
}
}
}

Expand Down
94 changes: 93 additions & 1 deletion builtin/credential/approle/path_tidy_user_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package approle

import (
"context"
"fmt"
"sync"
"testing"
"time"

"github.com/hashicorp/vault/logical"
)

func TestAppRole_TidyDanglingAccessors(t *testing.T) {
func TestAppRole_TidyDanglingAccessors_Normal(t *testing.T) {
var resp *logical.Response
var err error
b, storage := createBackendWithStorage(t)
Expand Down Expand Up @@ -83,3 +85,93 @@ func TestAppRole_TidyDanglingAccessors(t *testing.T) {
t.Fatalf("bad: len(accessorHashes); expect 1, got %d", len(accessorHashes))
}
}

func TestAppRole_TidyDanglingAccessors_RaceTest(t *testing.T) {
var resp *logical.Response
var err error
b, storage := createBackendWithStorage(t)

b.testTidyDelay = 300 * time.Millisecond

// Create a role
createRole(t, b, storage, "role1", "a,b,c")

// Create an initial entry
roleSecretIDReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "role/role1/secret-id",
Storage: storage,
}
resp, err = b.HandleRequest(context.Background(), roleSecretIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
count := 1

wg := sync.WaitGroup{}
now := time.Now()
started := false
for {
if time.Now().Sub(now) > 700*time.Millisecond {
break
}
if time.Now().Sub(now) > 100*time.Millisecond && !started {
started = true
_, err = b.tidySecretID(context.Background(), &logical.Request{
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
}
go func() {
wg.Add(1)
defer wg.Done()
roleSecretIDReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "role/role1/secret-id",
Storage: storage,
}
resp, err = b.HandleRequest(context.Background(), roleSecretIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
}()
count++
}

t.Logf("wrote %d entries", count)

wg.Wait()
// Let tidy finish
time.Sleep(1 * time.Second)

// Run tidy again
_, err = b.tidySecretID(context.Background(), &logical.Request{
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
time.Sleep(2 * time.Second)

accessorHashes, err := storage.List(context.Background(), "accessor/")
if err != nil {
t.Fatal(err)
}
if len(accessorHashes) != count {
t.Fatalf("bad: len(accessorHashes); expect %d, got %d", count, len(accessorHashes))
}

roleHMACs, err := storage.List(context.Background(), secretIDPrefix)
if err != nil {
t.Fatal(err)
}
secretIDs, err := storage.List(context.Background(), fmt.Sprintf("%s%s", secretIDPrefix, roleHMACs[0]))
if err != nil {
t.Fatal(err)
}
if len(secretIDs) != count {
t.Fatalf("bad: len(secretIDs); expect %d, got %d", count, len(secretIDs))
}
}