Skip to content

Commit

Permalink
Token revocation refactor (#4512)
Browse files Browse the repository at this point in the history
* Hand off lease expiration to expiration manager via timers

* Use sync.Map as the cache to track token deletion state

* Add CreateOrFetchRevocationLeaseByToken to hand off token revocation to exp manager

* Update revoke and revoke-self handlers

* Fix tests

* revokeSalted: Move token entry deletion into the deferred func

* Fix test race

* Add blocking lease revocation test

* Remove test log

* Add HandlerFunc on NoopBackend, adjust locks, and add test

* Add sleep to allow for revocations to settle

* Various updates

* Rename some functions and variables to be more clear
* Change step-down and seal to use expmgr for revoke functionality like
during request handling
* Attempt to WAL the token as being invalid as soon as possible so that
further usage will fail even if revocation does not fully complete

* Address feedback

* Return invalid lease on negative TTL

* Revert "Return invalid lease on negative TTL"

This reverts commit a39597e.

* Extend sleep on tests
  • Loading branch information
calvn authored May 10, 2018
1 parent 2ef3635 commit 0678d6b
Show file tree
Hide file tree
Showing 10 changed files with 356 additions and 135 deletions.
14 changes: 10 additions & 4 deletions vault/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -1429,10 +1429,13 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr
return retErr
}

if te != nil && te.NumUses == -1 {
if te != nil && te.NumUses == tokenRevocationPending {
// Token needs to be revoked. We do this immediately here because
// we won't have a token store after sealing.
err = c.tokenStore.Revoke(c.activeContext, te.ID)
leaseID, err := c.expiration.CreateOrFetchRevocationLeaseByToken(te)
if err == nil {
err = c.expiration.Revoke(leaseID)
}
if err != nil {
c.logger.Error("token needed revocation before seal but failed to revoke", "error", err)
retErr = multierror.Append(retErr, ErrInternalError)
Expand Down Expand Up @@ -1540,10 +1543,13 @@ func (c *Core) StepDown(req *logical.Request) (retErr error) {
return retErr
}

if te != nil && te.NumUses == -1 {
if te != nil && te.NumUses == tokenRevocationPending {
// Token needs to be revoked. We do this immediately here because
// we won't have a token store after sealing.
err = c.tokenStore.Revoke(c.activeContext, te.ID)
leaseID, err := c.expiration.CreateOrFetchRevocationLeaseByToken(te)
if err == nil {
err = c.expiration.Revoke(leaseID)
}
if err != nil {
c.logger.Error("token needed revocation before step-down but failed to revoke", "error", err)
retErr = multierror.Append(retErr, ErrInternalError)
Expand Down
80 changes: 73 additions & 7 deletions vault/expiration.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,18 +561,34 @@ func (m *ExpirationManager) RevokeByToken(te *TokenEntry) error {
defer metrics.MeasureSince([]string{"expire", "revoke-by-token"}, time.Now())

// Lookup the leases
existing, err := m.lookupByToken(te.ID)
existing, err := m.lookupLeasesByToken(te.ID)
if err != nil {
return errwrap.Wrapf("failed to scan for leases: {{err}}", err)
}

// Revoke all the keys
for idx, leaseID := range existing {
if err := m.revokeCommon(leaseID, false, false); err != nil {
return errwrap.Wrapf(fmt.Sprintf("failed to revoke %q (%d / %d): {{err}}", leaseID, idx+1, len(existing)), err)
for _, leaseID := range existing {
// Load the entry
le, err := m.loadEntry(leaseID)
if err != nil {
return err
}

// If there's a lease, set expiration to now, persist, and call
// updatePending to hand off revocation to the expiration manager's pending
// timer map
if le != nil {
le.ExpireTime = time.Now()

if err := m.persistEntry(le); err != nil {
return err
}

m.updatePending(le, 0)
}
}

// te.Path should never be empty, but we check just in case
if te.Path != "" {
saltedID, err := m.tokenStore.SaltID(m.quitContext, te.ID)
if err != nil {
Expand Down Expand Up @@ -1054,7 +1070,7 @@ func (m *ExpirationManager) revokeEntry(le *leaseEntry) error {
// Revocation of login tokens is special since we can by-pass the
// backend and directly interact with the token store
if le.Auth != nil {
if err := m.tokenStore.RevokeTree(m.quitContext, le.ClientToken); err != nil {
if err := m.tokenStore.revokeTree(m.quitContext, le.ClientToken); err != nil {
return errwrap.Wrapf("failed to revoke token: {{err}}", err)
}

Expand Down Expand Up @@ -1247,8 +1263,58 @@ func (m *ExpirationManager) removeIndexByToken(token, leaseID string) error {
return nil
}

// lookupByToken is used to lookup all the leaseID's via the
func (m *ExpirationManager) lookupByToken(token string) ([]string, error) {
// CreateOrFetchRevocationLeaseByToken is used to create or fetch the matching
// leaseID for a particular token. The lease is set to expire immediately after
// it's created.
func (m *ExpirationManager) CreateOrFetchRevocationLeaseByToken(te *TokenEntry) (string, error) {
// Fetch the saltedID of the token and construct the leaseID
saltedID, err := m.tokenStore.SaltID(m.quitContext, te.ID)
if err != nil {
return "", err
}
leaseID := path.Join(te.Path, saltedID)

// Load the entry
le, err := m.loadEntry(leaseID)
if err != nil {
return "", err
}

// If there's no associated leaseEntry for the token, we create one
if le == nil {
auth := &logical.Auth{
ClientToken: te.ID,
LeaseOptions: logical.LeaseOptions{
TTL: time.Nanosecond,
},
}

if strings.Contains(te.Path, "..") {
return "", consts.ErrPathContainsParentReferences
}

// Create a lease entry
now := time.Now()
le = &leaseEntry{
LeaseID: leaseID,
ClientToken: auth.ClientToken,
Auth: auth,
Path: te.Path,
IssueTime: now,
ExpireTime: now.Add(time.Nanosecond),
}

// Encode the entry
if err := m.persistEntry(le); err != nil {
return "", err
}
}

return le.LeaseID, nil
}

// lookupLeasesByToken is used to lookup all the leaseID's via the tokenID
func (m *ExpirationManager) lookupLeasesByToken(token string) ([]string, error) {
saltedID, err := m.tokenStore.SaltID(m.quitContext, token)
if err != nil {
return nil, err
Expand Down
104 changes: 104 additions & 0 deletions vault/expiration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,108 @@ func TestExpiration_RevokeByToken(t *testing.T) {
t.Fatalf("err: %v", err)
}

time.Sleep(300 * time.Millisecond)

noop.Lock()
defer noop.Unlock()

if len(noop.Requests) != 3 {
t.Fatalf("Bad: %v", noop.Requests)
}
for _, req := range noop.Requests {
if req.Operation != logical.RevokeOperation {
t.Fatalf("Bad: %v", req)
}
}

expect := []string{
"foo",
"sub/bar",
"zip",
}
sort.Strings(noop.Paths)
sort.Strings(expect)
if !reflect.DeepEqual(noop.Paths, expect) {
t.Fatalf("bad: %v", noop.Paths)
}
}

func TestExpiration_RevokeByToken_Blocking(t *testing.T) {
exp := mockExpiration(t)
noop := &NoopBackend{}
// Request handle with a timeout context that simulates blocking lease revocation.
noop.RequestHandler = func(ctx context.Context, req *logical.Request) (*logical.Response, error) {
ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
defer cancel()

select {
case <-ctx.Done():
return noop.Response, nil
}
}

_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")
meUUID, err := uuid.GenerateUUID()
if err != nil {
t.Fatal(err)
}
err = exp.router.Mount(noop, "prod/aws/", &MountEntry{Path: "prod/aws/", Type: "noop", UUID: meUUID, Accessor: "noop-accessor"}, view)
if err != nil {
t.Fatal(err)
}

paths := []string{
"prod/aws/foo",
"prod/aws/sub/bar",
"prod/aws/zip",
}
for _, path := range paths {
req := &logical.Request{
Operation: logical.ReadOperation,
Path: path,
ClientToken: "foobarbaz",
}
resp := &logical.Response{
Secret: &logical.Secret{
LeaseOptions: logical.LeaseOptions{
TTL: 1 * time.Minute,
},
},
Data: map[string]interface{}{
"access_key": "xyz",
"secret_key": "abcd",
},
}
_, err := exp.Register(req, resp)
if err != nil {
t.Fatalf("err: %v", err)
}
}

// Should nuke all the keys
te := &TokenEntry{
ID: "foobarbaz",
}
if err := exp.RevokeByToken(te); err != nil {
t.Fatalf("err: %v", err)
}

// Lock and check that no requests has gone through yet
noop.Lock()
if len(noop.Requests) != 0 {
t.Fatalf("Bad: %v", noop.Requests)
}
noop.Unlock()

// Wait for a bit for timeouts to trigger and pending revocations to go
// through and then we relock
time.Sleep(300 * time.Millisecond)

noop.Lock()
defer noop.Unlock()

// Now make sure that all requests have gone through
if len(noop.Requests) != 3 {
t.Fatalf("Bad: %v", noop.Requests)
}
Expand Down Expand Up @@ -1239,6 +1341,8 @@ func TestExpiration_revokeEntry_token(t *testing.T) {
t.Fatalf("err: %v", err)
}

time.Sleep(300 * time.Millisecond)

out, err := exp.tokenStore.Lookup(context.Background(), le.ClientToken)
if err != nil {
t.Fatalf("err: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion vault/generate_root.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (g generateStandardRootToken) generate(ctx context.Context, c *Core) (strin
}

cleanupFunc := func() {
c.tokenStore.Revoke(ctx, te.ID)
c.tokenStore.revokeOrphan(ctx, te.ID)
}

return te.ID, cleanupFunc, nil
Expand Down
4 changes: 2 additions & 2 deletions vault/logical_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -3184,7 +3184,7 @@ func (b *SystemBackend) responseWrappingUnwrap(ctx context.Context, token string
return "", errwrap.Wrapf("error decrementing wrapping token's use-count: {{err}}", err)
}

defer b.Core.tokenStore.Revoke(ctx, token)
defer b.Core.tokenStore.revokeOrphan(ctx, token)
}

cubbyReq := &logical.Request{
Expand Down Expand Up @@ -3294,7 +3294,7 @@ func (b *SystemBackend) handleWrappingRewrap(ctx context.Context, req *logical.R
if err != nil {
return nil, errwrap.Wrapf("error decrementing wrapping token's use-count: {{err}}", err)
}
defer b.Core.tokenStore.Revoke(ctx, token)
defer b.Core.tokenStore.revokeOrphan(ctx, token)
}

// Fetch the original TTL
Expand Down
11 changes: 7 additions & 4 deletions vault/request_handling.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,15 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
return nil, nil, retErr
}
if te.NumUses == -1 {
if te.NumUses == tokenRevocationPending {
// We defer a revocation until after logic has run, since this is a
// valid request (this is the token's final use). We pass the ID in
// directly just to be safe in case something else modifies te later.
defer func(id string) {
err = c.tokenStore.Revoke(ctx, id)
leaseID, err := c.expiration.CreateOrFetchRevocationLeaseByToken(te)
if err == nil {
err = c.expiration.Revoke(leaseID)
}
if err != nil {
c.logger.Error("failed to revoke token", "error", err)
retResp = nil
Expand Down Expand Up @@ -398,7 +401,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
}

if err := c.expiration.RegisterAuth(te.Path, resp.Auth); err != nil {
c.tokenStore.Revoke(ctx, te.ID)
c.tokenStore.revokeOrphan(ctx, te.ID)
c.logger.Error("failed to register token lease", "request_path", req.Path, "error", err)
retErr = multierror.Append(retErr, ErrInternalError)
return nil, auth, retErr
Expand Down Expand Up @@ -604,7 +607,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re

// Register with the expiration manager
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
c.tokenStore.Revoke(ctx, te.ID)
c.tokenStore.revokeOrphan(ctx, te.ID)
c.logger.Error("failed to register token lease", "request_path", req.Path, "error", err)
return nil, auth, ErrInternalError
}
Expand Down
11 changes: 10 additions & 1 deletion vault/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"github.com/hashicorp/vault/logical"
)

type HandlerFunc func(context.Context, *logical.Request) (*logical.Response, error)

type NoopBackend struct {
sync.Mutex

Expand All @@ -22,12 +24,19 @@ type NoopBackend struct {
Paths []string
Requests []*logical.Request
Response *logical.Response
RequestHandler HandlerFunc
Invalidations []string
DefaultLeaseTTL time.Duration
MaxLeaseTTL time.Duration
}

func (n *NoopBackend) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) {
var err error
resp := n.Response
if n.RequestHandler != nil {
resp, err = n.RequestHandler(ctx, req)
}

n.Lock()
defer n.Unlock()

Expand All @@ -38,7 +47,7 @@ func (n *NoopBackend) HandleRequest(ctx context.Context, req *logical.Request) (
return nil, fmt.Errorf("missing view")
}

return n.Response, nil
return resp, err
}

func (n *NoopBackend) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) {
Expand Down
Loading

0 comments on commit 0678d6b

Please sign in to comment.