diff --git a/README.md b/README.md index a6e1f04..48753f4 100644 --- a/README.md +++ b/README.md @@ -491,9 +491,12 @@ the corresponding `creds/:name` path. #### `GET` (`read`) Retrieve a new access token by performing a token exchange request on demand. -The token exchange operation always sends the access token from the +The token exchange operation sends the access token from the corresponding credential as the subject token and explicitly requests a new access token from the authorization server. +Reuses previous token that was made with the same parameters +if the provider specified an expiration time +and the token is not yet expired or close to it. Parameters: @@ -502,6 +505,7 @@ Parameters: | `scopes` | A list of explicit scopes to request. | List of String | None | No | | `audiences` | A list of explicit audiences to request. | List of String | None | No | | `resources` | A list of explicit resources to request. | List of String | None | No | +| `minimum_seconds` | Minimum additional duration to require the access token to be valid for. | Integer | 10[3](#footnote-3) | No | ## Providers diff --git a/pkg/backend/path_creds.go b/pkg/backend/path_creds.go index 8aa9e76..6fbb9e7 100644 --- a/pkg/backend/path_creds.go +++ b/pkg/backend/path_creds.go @@ -71,7 +71,7 @@ func (b *backend) credsReadOperation(ctx context.Context, req *logical.Request, } return logical.ErrorResponse("token pending issuance"), nil - case !b.tokenValid(entry.Token, expiryDelta): + case !b.tokenValid(entry.Token.Token, expiryDelta): if entry.AuthServerError != "" { return logical.ErrorResponse("server %q has configuration problems: %s", entry.AuthServerName, entry.AuthServerError), nil } else if entry.UserError != "" { diff --git a/pkg/backend/path_self.go b/pkg/backend/path_self.go index dde4fa9..277266a 100644 --- a/pkg/backend/path_self.go +++ b/pkg/backend/path_self.go @@ -32,7 +32,7 @@ func (b *backend) selfReadOperation(ctx context.Context, req *logical.Request, d return nil, err case entry == nil: return nil, nil - case !b.tokenValid(entry.Token, expiryDelta): + case !b.tokenValid(entry.Token.Token, expiryDelta): return logical.ErrorResponse("token expired"), nil } diff --git a/pkg/backend/path_sts.go b/pkg/backend/path_sts.go index a8b4dbb..e92eccf 100644 --- a/pkg/backend/path_sts.go +++ b/pkg/backend/path_sts.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strings" + "time" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" @@ -11,13 +12,15 @@ import ( "github.com/puppetlabs/leg/errmap/pkg/errmark" "github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/persistence" "github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/provider" + "golang.org/x/oauth2" ) func (b *backend) stsReadOperation(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + keyer := persistence.AuthCodeName(data.Get("name").(string)) entry, err := b.getRefreshCredToken( ctx, req.Storage, - persistence.AuthCodeName(data.Get("name").(string)), + keyer, defaultExpiryDelta, ) switch { @@ -33,7 +36,7 @@ func (b *backend) stsReadOperation(ctx context.Context, req *logical.Request, da } return logical.ErrorResponse("token pending issuance"), nil - case !b.tokenValid(entry.Token, defaultExpiryDelta): + case !b.tokenValid(entry.Token.Token, defaultExpiryDelta): if entry.AuthServerError != "" { return logical.ErrorResponse("server %q has configuration problems: %s", entry.AuthServerName, entry.AuthServerError), nil } else if entry.UserError != "" { @@ -43,26 +46,60 @@ func (b *backend) stsReadOperation(ctx context.Context, req *logical.Request, da return logical.ErrorResponse("token expired"), nil } - ops, put, err := b.getProviderOperations(ctx, req.Storage, persistence.AuthServerName(entry.AuthServerName), defaultExpiryDelta) - if errmark.MarkedUser(err) { - return logical.ErrorResponse(fmt.Errorf("server %q has configuration problems: %w", entry.AuthServerName, errmark.MarkShort(err)).Error()), nil - } else if err != nil { - return nil, err - } - defer put() + scopes := data.Get("scopes").([]string) + audiences := data.Get("audiences").([]string) + resources := data.Get("resources").([]string) + exchangeKey := "scopes=" + strings.Join(scopes, " ") + + ",audiences=" + strings.Join(audiences, " ") + + ",resources=" + strings.Join(resources, " ") + expiryDelta := time.Duration(data.Get("minimum_seconds").(int)) * time.Second - tok, err := ops.TokenExchange( - ctx, - entry.Token, - provider.WithScopes(data.Get("scopes").([]string)), - provider.WithAudiences(data.Get("audiences").([]string)), - provider.WithResources(data.Get("resources").([]string)), - provider.WithProviderOptions(entry.ProviderOptions), - ) - if errmark.MarkedUser(err) { - return logical.ErrorResponse(errmap.Wrap(errmark.MarkShort(err), "exchange failed").Error()), nil - } else if err != nil { - return nil, err + tok, ok := entry.ExchangedTokens[exchangeKey] + if !ok || !b.tokenValid(tok, expiryDelta) { + ops, put, err := b.getProviderOperations(ctx, req.Storage, persistence.AuthServerName(entry.AuthServerName), defaultExpiryDelta) + if errmark.MarkedUser(err) { + return logical.ErrorResponse(fmt.Errorf("server %q has configuration problems: %w", entry.AuthServerName, errmark.MarkShort(err)).Error()), nil + } else if err != nil { + return nil, err + } + defer put() + + exchangedTok, err := ops.TokenExchange( + ctx, + entry.Token, + provider.WithScopes(scopes), + provider.WithAudiences(audiences), + provider.WithResources(resources), + provider.WithProviderOptions(entry.ProviderOptions), + ) + if errmark.MarkedUser(err) { + return logical.ErrorResponse(errmap.Wrap(errmark.MarkShort(err), "exchange failed").Error()), nil + } else if err != nil { + return nil, err + } + if !b.tokenValid(exchangedTok.Token, expiryDelta) { + return logical.ErrorResponse("token expired"), nil + } + + // copy into smaller struct for caching + tok = &oauth2.Token{ + AccessToken: exchangedTok.Token.AccessToken, + TokenType: exchangedTok.Token.TokenType, + Expiry: exchangedTok.Token.Expiry, + } + + if !tok.Expiry.IsZero() { + // Cache the token since it has an expiration time + err = b.storeExchangedToken( + ctx, + req.Storage, + keyer, + exchangeKey, + tok) + if err != nil { + return nil, err + } + } } rd := map[string]interface{}{ @@ -103,6 +140,12 @@ var stsFields = map[string]*framework.FieldSchema{ Description: "Specifies the target RFC 8707 resource indicators for the minted token.", Query: true, }, + "minimum_seconds": { + Type: framework.TypeDurationSecond, + Description: "Minimum remaining seconds to allow when reusing exchanged access token.", + Default: 0, + Query: true, + }, } const stsHelpSynopsis = ` diff --git a/pkg/backend/token.go b/pkg/backend/token.go index 31744d8..7152a03 100644 --- a/pkg/backend/token.go +++ b/pkg/backend/token.go @@ -4,14 +4,14 @@ import ( "time" "github.com/puppetlabs/leg/timeutil/pkg/clock" - "github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/provider" + "golang.org/x/oauth2" ) const ( defaultExpiryDelta = 10 * time.Second ) -func tokenExpired(clk clock.Clock, t *provider.Token, expiryDelta time.Duration) bool { +func tokenExpired(clk clock.Clock, t *oauth2.Token, expiryDelta time.Duration) bool { if t.Expiry.IsZero() { return false } @@ -23,6 +23,6 @@ func tokenExpired(clk clock.Clock, t *provider.Token, expiryDelta time.Duration) return t.Expiry.Round(0).Add(-expiryDelta).Before(clk.Now()) } -func (b *backend) tokenValid(tok *provider.Token, expiryDelta time.Duration) bool { +func (b *backend) tokenValid(tok *oauth2.Token, expiryDelta time.Duration) bool { return tok != nil && tok.AccessToken != "" && !tokenExpired(b.clock, tok, expiryDelta) } diff --git a/pkg/backend/token_authcode.go b/pkg/backend/token_authcode.go index a235d67..a74a489 100644 --- a/pkg/backend/token_authcode.go +++ b/pkg/backend/token_authcode.go @@ -16,6 +16,7 @@ import ( "github.com/puppetlabs/leg/timeutil/pkg/retry" "github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/persistence" "github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/provider" + "golang.org/x/oauth2" ) type refreshProcess struct { @@ -110,7 +111,7 @@ func (b *backend) refreshCredToken(ctx context.Context, storage logical.Storage, switch { case err != nil || candidate == nil: return err - case !candidate.TokenIssued() || b.tokenValid(candidate.Token, expiryDelta) || candidate.RefreshToken == "": + case !candidate.TokenIssued() || b.tokenValid(candidate.Token.Token, expiryDelta) || candidate.RefreshToken == "": entry = candidate return nil } @@ -155,9 +156,42 @@ func (b *backend) getRefreshCredToken(ctx context.Context, storage logical.Stora return nil, err case entry == nil: return nil, nil - case !entry.TokenIssued() || b.tokenValid(entry.Token, expiryDelta): + case !entry.TokenIssued() || b.tokenValid(entry.Token.Token, expiryDelta): return entry, nil default: return b.refreshCredToken(ctx, storage, keyer, expiryDelta) } } + +func (b *backend) storeExchangedToken(ctx context.Context, storage logical.Storage, keyer persistence.AuthCodeKeyer, exchangeKey string, tok *oauth2.Token) error { + ctx = clockctx.WithClock(ctx, b.clock) + + err := b.data.AuthCode.WithLock(keyer, func(ach *persistence.LockedAuthCodeHolder) error { + acm := ach.Manager(storage) + + entry, err := acm.ReadAuthCodeEntry(ctx) + if err != nil || entry == nil { + return err + } + + if entry.ExchangedTokens == nil { + // first time, make the map + entry.ExchangedTokens = make(map[string]*oauth2.Token) + } else { + // remove every expired exchanged token while we're here + for k, t := range entry.ExchangedTokens { + if !b.tokenValid(t, defaultExpiryDelta) { + delete(entry.ExchangedTokens, k) + } + } + } + entry.ExchangedTokens[exchangeKey] = tok + + if err := acm.WriteAuthCodeEntry(ctx, entry); err != nil { + return err + } + + return nil + }) + return err +} diff --git a/pkg/backend/token_authcode_test.go b/pkg/backend/token_authcode_test.go index 669091e..1118d92 100644 --- a/pkg/backend/token_authcode_test.go +++ b/pkg/backend/token_authcode_test.go @@ -3,6 +3,7 @@ package backend_test import ( "context" "fmt" + "strings" "testing" "time" @@ -14,6 +15,7 @@ import ( "github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/provider" "github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/testutil" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" testclock "k8s.io/utils/clock/testing" ) @@ -290,8 +292,14 @@ func TestMinimumSeconds(t *testing.T) { Secret: "def", } - exchanges := map[string]testutil.MockAuthCodeExchangeFunc{ - "first": testutil.RandomMockAuthCodeExchange, + authCodeExchanges := map[string]testutil.MockAuthCodeExchangeFunc{ + "first": testutil.StaticMockAuthCodeExchange( + &provider.Token{ + Token: &oauth2.Token{ + AccessToken: "first", + }, + }, + ), "second": testutil.RefreshableMockAuthCodeExchange( testutil.IncrementMockAuthCodeExchange("second_"), func(i int) (time.Duration, error) { @@ -300,9 +308,31 @@ func TestMinimumSeconds(t *testing.T) { }, ), } + tokenExchanges := map[string]testutil.MockTokenExchangeFunc{ + // use same names as above to re-use the refresh tokens + "first": testutil.StaticMockTokenExchange( + &provider.Token{ + Token: &oauth2.Token{ + AccessToken: "first", + }, + }, + ), + "second": testutil.ExpiringMockTokenExchangeStep( + testutil.IncrementMockTokenExchange("second_"), + func(i int) (time.Duration, error) { + // add 30 seconds for each subsequent read + return (70 + time.Duration(i)*30) * time.Second, nil + }, + ), + } pr := provider.NewRegistry() - pr.MustRegister("mock", testutil.MockFactory(testutil.MockWithAuthCodeExchange(client, testutil.RestrictMockAuthCodeExchange(exchanges)))) + pr.MustRegister("mock", testutil.MockFactory( + testutil.MockWithAuthCodeExchange(client, + testutil.RestrictMockAuthCodeExchange(authCodeExchanges)), + testutil.MockWithTokenExchange(client, + testutil.RestrictMockTokenExchange(tokenExchanges)), + )) storage := &logical.InmemStorage{} @@ -331,7 +361,7 @@ func TestMinimumSeconds(t *testing.T) { require.Nil(t, resp) // Write our credentials. - for code := range exchanges { + for code := range authCodeExchanges { req = &logical.Request{ Operation: logical.UpdateOperation, Path: backend.CredsPathPrefix + code, @@ -347,6 +377,19 @@ func TestMinimumSeconds(t *testing.T) { require.False(t, resp != nil && resp.IsError(), "response has error: %+v", resp.Error()) require.Nil(t, resp) } + for code := range tokenExchanges { + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: backend.STSPathPrefix + code, + Storage: storage, + } + + resp, err = b.HandleRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.False(t, resp.IsError(), "response has error: %+v", resp.Error()) + require.True(t, len(resp.Warnings) == 0) + } tokens := make(map[string]string) tests := []struct { @@ -364,6 +407,7 @@ func TestMinimumSeconds(t *testing.T) { { Name: "make sure minimum_seconds added to first does not generate new token", Token: "first", + Data: map[string]interface{}{"minimum_seconds": "10"}, ExpectedAccessToken: func() string { return tokens["first"] }, }, // The initial token will be issued at +100s (70 seconds + 30 seconds), @@ -384,7 +428,7 @@ func TestMinimumSeconds(t *testing.T) { ExpectedAccessToken: func() string { return "second_2" }, ExpectedExpireTime: true, }, - // If we ask for +140s (> +130s), we should get another refresh. + // If we ask for +140s (> +130s), we should get another refresh to +160s. { Name: "test minimum_seconds more than the expiry of the second token", Token: "second", @@ -392,8 +436,8 @@ func TestMinimumSeconds(t *testing.T) { ExpectedAccessToken: func() string { return "second_3" }, ExpectedExpireTime: true, }, - // Finally, if we ask for something outside the range of what we can - // reasonably issue, we'll just get an error. + // Finally, if we ask for something outside the range of what is issued, + // we'll just get an error. { Name: "verify that second is marked expired if new token is less than request", Token: "second", @@ -401,34 +445,39 @@ func TestMinimumSeconds(t *testing.T) { ExpectedError: "token expired", }, } - for _, test := range tests { - t.Run(test.Name, func(t *testing.T) { - req := &logical.Request{ - Operation: logical.ReadOperation, - Path: backend.CredsPathPrefix + test.Token, - Storage: storage, - Data: test.Data, - } - resp, err := b.HandleRequest(ctx, req) - require.NoError(t, err) - require.NotNil(t, resp) - if test.ExpectedError != "" { - require.True(t, resp.IsError()) - require.EqualError(t, resp.Error(), test.ExpectedError) - } else { - require.False(t, resp.IsError(), "response has error: %+v", resp.Error()) - require.Equal(t, test.ExpectedExpireTime, resp.Data["expire_time"] != nil) - - if test.ExpectedAccessToken != nil { - require.Equal(t, test.ExpectedAccessToken(), resp.Data["access_token"]) - } else { - require.NotEmpty(t, resp.Data["access_token"]) + prefixes := []string{backend.CredsPathPrefix, backend.STSPathPrefix} + for _, prefix := range prefixes { + pfx := strings.ReplaceAll(prefix, "/", "") + for _, test := range tests { + t.Run(pfx+" "+test.Name, func(t *testing.T) { + req := &logical.Request{ + Operation: logical.ReadOperation, + Path: prefix + test.Token, + Storage: storage, + Data: test.Data, } - tokens[test.Token] = resp.Data["access_token"].(string) - } - }) + resp, err := b.HandleRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + if test.ExpectedError != "" { + require.True(t, resp.IsError()) + require.EqualError(t, resp.Error(), test.ExpectedError) + } else { + require.False(t, resp.IsError(), "response has error: %+v", resp.Error()) + require.Equal(t, test.ExpectedExpireTime, resp.Data["expire_time"] != nil) + + if test.ExpectedAccessToken != nil { + require.Equal(t, test.ExpectedAccessToken(), resp.Data["access_token"]) + } else { + require.NotEmpty(t, resp.Data["access_token"]) + } + + tokens[test.Token] = resp.Data["access_token"].(string) + } + }) + } } } diff --git a/pkg/backend/token_clientcreds.go b/pkg/backend/token_clientcreds.go index daa0191..a0d8384 100644 --- a/pkg/backend/token_clientcreds.go +++ b/pkg/backend/token_clientcreds.go @@ -25,7 +25,7 @@ func (b *backend) updateClientCredsToken(ctx context.Context, storage logical.St switch { case err != nil || candidate == nil: return err - case b.tokenValid(candidate.Token, expiryDelta): + case b.tokenValid(candidate.Token.Token, expiryDelta): entry = candidate return nil } @@ -66,7 +66,7 @@ func (b *backend) getUpdateClientCredsToken(ctx context.Context, storage logical switch { case err != nil: return nil, err - case entry != nil && b.tokenValid(entry.Token, expiryDelta): + case entry != nil && b.tokenValid(entry.Token.Token, expiryDelta): return entry, nil default: return b.updateClientCredsToken(ctx, storage, keyer, expiryDelta) diff --git a/pkg/persistence/authcode.go b/pkg/persistence/authcode.go index 6f5ffc9..2c0c015 100644 --- a/pkg/persistence/authcode.go +++ b/pkg/persistence/authcode.go @@ -15,6 +15,7 @@ import ( "github.com/puppetlabs/leg/timeutil/pkg/clockctx" "github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/provider" "github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/vaultext" + "golang.org/x/oauth2" ) const ( @@ -68,6 +69,9 @@ type AuthCodeEntry struct { // If the most recent exchange did not succeed, this holds the time that // exchange occurred. LastAttemptedIssueTime time.Time `json:"last_attempted_issue_time,omitempty"` + + // Cache of successfully exchanged tokens + ExchangedTokens map[string]*oauth2.Token `json:"exchanged_tokens"` } func (ace *AuthCodeEntry) SetToken(ctx context.Context, tok *provider.Token) { diff --git a/pkg/testutil/mock_tokenexchange.go b/pkg/testutil/mock_tokenexchange.go index 2485803..03dcc23 100644 --- a/pkg/testutil/mock_tokenexchange.go +++ b/pkg/testutil/mock_tokenexchange.go @@ -3,6 +3,7 @@ package testutil import ( "fmt" "net/http" + "strings" "sync/atomic" "time" @@ -39,6 +40,20 @@ func ExpiringMockTokenExchange(fn MockTokenExchangeFunc, duration time.Duration) }) } +func ExpiringMockTokenExchangeStep(fn MockTokenExchangeFunc, step func(i int) (time.Duration, error)) MockTokenExchangeFunc { + var i int32 + + return AmendTokenMockTokenExchange(fn, func(t *provider.Token) error { + exp, err := step(int(atomic.AddInt32(&i, 1))) + if err != nil { + return err + } + + t.Expiry = time.Now().Add(exp) + return nil + }) +} + func IncrementMockTokenExchange(prefix string) MockTokenExchangeFunc { var i int32 @@ -64,7 +79,15 @@ func FilterMockTokenExchange(fn MockTokenExchangeFunc, filters ...func(t *provid func RestrictMockTokenExchange(m map[string]MockTokenExchangeFunc) MockTokenExchangeFunc { return func(t *provider.Token, opts *provider.TokenExchangeOptions) (*provider.Token, error) { - fn, found := m[t.AccessToken] + found := false + var name string + var fn MockTokenExchangeFunc + for name, fn = range m { + if strings.HasPrefix(t.AccessToken, name) { + found = true + break + } + } if !found { return nil, MockErrorResponse(http.StatusForbidden, &interop.JSONError{Error: "access_denied"}) }