Skip to content
This repository has been archived by the owner on Nov 18, 2024. It is now read-only.

Add caching of STS-exchanged access tokens #90

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,12 @@ the corresponding `creds/:name` path.
#### `GET` (`read`)

Retrieve a new access token by performing a token exchange request on demand.
The token exchange operation always sends the access token from the
The token exchange operation sends the access token from the
corresponding credential as the subject token and explicitly requests a new
access token from the authorization server.
Reuses previous token that was made with the same parameters
if the provider specified an expiration time
and the token is not yet expired or close to it.

Parameters:

Expand All @@ -502,6 +505,7 @@ Parameters:
| `scopes` | A list of explicit scopes to request. | List of String | None | No |
| `audiences` | A list of explicit audiences to request. | List of String | None | No |
| `resources` | A list of explicit resources to request. | List of String | None | No |
| `minimum_seconds` | Minimum additional duration to require the access token to be valid for. | Integer | 10<sup id="ret-3-b">[3](#footnote-3)</sup> | No |

## Providers

Expand Down
2 changes: 1 addition & 1 deletion pkg/backend/path_creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (b *backend) credsReadOperation(ctx context.Context, req *logical.Request,
}

return logical.ErrorResponse("token pending issuance"), nil
case !b.tokenValid(entry.Token, expiryDelta):
case !b.tokenValid(entry.Token.Token, expiryDelta):
if entry.AuthServerError != "" {
return logical.ErrorResponse("server %q has configuration problems: %s", entry.AuthServerName, entry.AuthServerError), nil
} else if entry.UserError != "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/backend/path_self.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (b *backend) selfReadOperation(ctx context.Context, req *logical.Request, d
return nil, err
case entry == nil:
return nil, nil
case !b.tokenValid(entry.Token, expiryDelta):
case !b.tokenValid(entry.Token.Token, expiryDelta):
return logical.ErrorResponse("token expired"), nil
}

Expand Down
87 changes: 65 additions & 22 deletions pkg/backend/path_sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,25 @@ import (
"context"
"fmt"
"strings"
"time"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/puppetlabs/leg/errmap/pkg/errmap"
"github.com/puppetlabs/leg/errmap/pkg/errmark"
"github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/persistence"
"github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/provider"
"golang.org/x/oauth2"
)

func (b *backend) stsReadOperation(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
keyer := persistence.AuthCodeName(data.Get("name").(string))
expiryDelta := time.Duration(data.Get("minimum_seconds").(int)) * time.Second
entry, err := b.getRefreshCredToken(
ctx,
req.Storage,
persistence.AuthCodeName(data.Get("name").(string)),
defaultExpiryDelta,
keyer,
expiryDelta,
)
switch {
case err != nil:
Expand All @@ -33,7 +37,7 @@ func (b *backend) stsReadOperation(ctx context.Context, req *logical.Request, da
}

return logical.ErrorResponse("token pending issuance"), nil
case !b.tokenValid(entry.Token, defaultExpiryDelta):
case !b.tokenValid(entry.Token.Token, expiryDelta):
if entry.AuthServerError != "" {
return logical.ErrorResponse("server %q has configuration problems: %s", entry.AuthServerName, entry.AuthServerError), nil
} else if entry.UserError != "" {
Expand All @@ -43,26 +47,59 @@ func (b *backend) stsReadOperation(ctx context.Context, req *logical.Request, da
return logical.ErrorResponse("token expired"), nil
}

ops, put, err := b.getProviderOperations(ctx, req.Storage, persistence.AuthServerName(entry.AuthServerName), defaultExpiryDelta)
if errmark.MarkedUser(err) {
return logical.ErrorResponse(fmt.Errorf("server %q has configuration problems: %w", entry.AuthServerName, errmark.MarkShort(err)).Error()), nil
} else if err != nil {
return nil, err
}
defer put()
scopes := data.Get("scopes").([]string)
audiences := data.Get("audiences").([]string)
resources := data.Get("resources").([]string)
exchangeKey := "scopes=" + strings.Join(scopes, " ") +
",audiences=" + strings.Join(audiences, " ") +
",resources=" + strings.Join(resources, " ")

tok, err := ops.TokenExchange(
ctx,
entry.Token,
provider.WithScopes(data.Get("scopes").([]string)),
provider.WithAudiences(data.Get("audiences").([]string)),
provider.WithResources(data.Get("resources").([]string)),
provider.WithProviderOptions(entry.ProviderOptions),
)
if errmark.MarkedUser(err) {
return logical.ErrorResponse(errmap.Wrap(errmark.MarkShort(err), "exchange failed").Error()), nil
} else if err != nil {
return nil, err
tok, ok := entry.ExchangedTokens[exchangeKey]
if !ok || !b.tokenValid(tok, expiryDelta) {
ops, put, err := b.getProviderOperations(ctx, req.Storage, persistence.AuthServerName(entry.AuthServerName), defaultExpiryDelta)
if errmark.MarkedUser(err) {
return logical.ErrorResponse(fmt.Errorf("server %q has configuration problems: %w", entry.AuthServerName, errmark.MarkShort(err)).Error()), nil
} else if err != nil {
return nil, err
}
defer put()

exchangedTok, err := ops.TokenExchange(
ctx,
entry.Token,
provider.WithScopes(scopes),
provider.WithAudiences(audiences),
provider.WithResources(resources),
provider.WithProviderOptions(entry.ProviderOptions),
)
if errmark.MarkedUser(err) {
return logical.ErrorResponse(errmap.Wrap(errmark.MarkShort(err), "exchange failed").Error()), nil
} else if err != nil {
return nil, err
}
if !b.tokenValid(exchangedTok.Token, expiryDelta) {
return logical.ErrorResponse("token expired"), nil
}

// copy into smaller struct for caching
tok = &oauth2.Token{
AccessToken: exchangedTok.Token.AccessToken,
TokenType: exchangedTok.Token.TokenType,
Expiry: exchangedTok.Token.Expiry,
}

if !tok.Expiry.IsZero() {
// Cache the token since it has an expiration time
err = b.storeExchangedToken(
ctx,
req.Storage,
keyer,
exchangeKey,
tok)
if err != nil {
return nil, err
}
}
}

rd := map[string]interface{}{
Expand Down Expand Up @@ -103,6 +140,12 @@ var stsFields = map[string]*framework.FieldSchema{
Description: "Specifies the target RFC 8707 resource indicators for the minted token.",
Query: true,
},
"minimum_seconds": {
Type: framework.TypeDurationSecond,
Description: "Minimum remaining seconds to allow when reusing exchanged access token.",
Default: 0,
Query: true,
},
}

const stsHelpSynopsis = `
Expand Down
6 changes: 3 additions & 3 deletions pkg/backend/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ import (
"time"

"github.com/puppetlabs/leg/timeutil/pkg/clock"
"github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/provider"
"golang.org/x/oauth2"
)

const (
defaultExpiryDelta = 10 * time.Second
)

func tokenExpired(clk clock.Clock, t *provider.Token, expiryDelta time.Duration) bool {
func tokenExpired(clk clock.Clock, t *oauth2.Token, expiryDelta time.Duration) bool {
if t.Expiry.IsZero() {
return false
}
Expand All @@ -23,6 +23,6 @@ func tokenExpired(clk clock.Clock, t *provider.Token, expiryDelta time.Duration)
return t.Expiry.Round(0).Add(-expiryDelta).Before(clk.Now())
}

func (b *backend) tokenValid(tok *provider.Token, expiryDelta time.Duration) bool {
func (b *backend) tokenValid(tok *oauth2.Token, expiryDelta time.Duration) bool {
return tok != nil && tok.AccessToken != "" && !tokenExpired(b.clock, tok, expiryDelta)
}
38 changes: 36 additions & 2 deletions pkg/backend/token_authcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/puppetlabs/leg/timeutil/pkg/retry"
"github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/persistence"
"github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/provider"
"golang.org/x/oauth2"
)

type refreshProcess struct {
Expand Down Expand Up @@ -110,7 +111,7 @@ func (b *backend) refreshCredToken(ctx context.Context, storage logical.Storage,
switch {
case err != nil || candidate == nil:
return err
case !candidate.TokenIssued() || b.tokenValid(candidate.Token, expiryDelta) || candidate.RefreshToken == "":
case !candidate.TokenIssued() || b.tokenValid(candidate.Token.Token, expiryDelta) || candidate.RefreshToken == "":
entry = candidate
return nil
}
Expand Down Expand Up @@ -155,9 +156,42 @@ func (b *backend) getRefreshCredToken(ctx context.Context, storage logical.Stora
return nil, err
case entry == nil:
return nil, nil
case !entry.TokenIssued() || b.tokenValid(entry.Token, expiryDelta):
case !entry.TokenIssued() || b.tokenValid(entry.Token.Token, expiryDelta):
return entry, nil
default:
return b.refreshCredToken(ctx, storage, keyer, expiryDelta)
}
}

func (b *backend) storeExchangedToken(ctx context.Context, storage logical.Storage, keyer persistence.AuthCodeKeyer, exchangeKey string, tok *oauth2.Token) error {
ctx = clockctx.WithClock(ctx, b.clock)

err := b.data.AuthCode.WithLock(keyer, func(ach *persistence.LockedAuthCodeHolder) error {
acm := ach.Manager(storage)

entry, err := acm.ReadAuthCodeEntry(ctx)
if err != nil || entry == nil {
return err
}

if entry.ExchangedTokens == nil {
// first time, make the map
entry.ExchangedTokens = make(map[string]*oauth2.Token)
} else {
// remove every expired exchanged token while we're here
for k, t := range entry.ExchangedTokens {
if !b.tokenValid(t, defaultExpiryDelta) {
delete(entry.ExchangedTokens, k)
}
}
}
entry.ExchangedTokens[exchangeKey] = tok

if err := acm.WriteAuthCodeEntry(ctx, entry); err != nil {
return err
}

return nil
})
return err
}
Loading
Loading