diff --git a/vault/expiration.go b/vault/expiration.go index 69d1098046f6..f6e4fb7e0bd1 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -1147,6 +1147,12 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenEntry, auth *logical.Auth) error { defer metrics.MeasureSince([]string{"expire", "register-auth"}, time.Now()) + authExpirationTime := auth.ExpirationTime() + + if te.TTL == 0 && authExpirationTime.IsZero() && (len(te.Policies) != 1 || te.Policies[0] != "root") { + return errors.New("refusing to register a lease for a non-root token with no TTL") + } + if te.Type == logical.TokenTypeBatch { return errors.New("cannot register a lease for a batch token") } @@ -1185,7 +1191,7 @@ func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenE Auth: auth, Path: te.Path, IssueTime: time.Now(), - ExpireTime: auth.ExpirationTime(), + ExpireTime: authExpirationTime, namespace: tokenNS, } diff --git a/vault/expiration_test.go b/vault/expiration_test.go index 942f8eb584bb..cb6af4cc541b 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -528,6 +528,7 @@ func TestExpiration_RegisterAuth_NoLease(t *testing.T) { te := &logical.TokenEntry{ ID: root.ID, Path: "auth/github/login", + Policies: []string{"root"}, NamespaceID: namespace.RootNamespaceID, } err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) @@ -562,6 +563,55 @@ func TestExpiration_RegisterAuth_NoLease(t *testing.T) { } } +// Tests both the expiration function and the core function +func TestExpiration_RegisterAuth_NoTTL(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + exp := c.expiration + ctx := namespace.RootContext(nil) + + root, err := exp.tokenStore.rootToken(context.Background()) + if err != nil { + t.Fatalf("err: %v", err) + } + + auth := &logical.Auth{ + ClientToken: root.ID, + TokenPolicies: []string{"root"}, + } + + // First on core + err = c.RegisterAuth(ctx, 0, "auth/github/login", auth) + if err != nil { + t.Fatal(err) + } + + auth.TokenPolicies[0] = "default" + err = c.RegisterAuth(ctx, 0, "auth/github/login", auth) + if err == nil { + t.Fatal("expected error") + } + + // Now expiration + // Should work, root token with zero TTL + te := &logical.TokenEntry{ + ID: root.ID, + Path: "auth/github/login", + Policies: []string{"root"}, + NamespaceID: namespace.RootNamespaceID, + } + err = exp.RegisterAuth(ctx, te, auth) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Test non-root token with zero TTL + te.Policies = []string{"default"} + err = exp.RegisterAuth(ctx, te, auth) + if err == nil { + t.Fatal("expected error") + } +} + func TestExpiration_Revoke(t *testing.T) { exp := mockExpiration(t) noop := &NoopBackend{} diff --git a/vault/request_handling.go b/vault/request_handling.go index 4f2e44f92ffd..f7381bd8311e 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -869,6 +869,8 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp case logical.TokenTypeBatch: case logical.TokenTypeService: if err := c.expiration.RegisterAuth(ctx, &logical.TokenEntry{ + TTL: auth.TTL, + Policies: auth.TokenPolicies, Path: resp.Auth.CreationPath, NamespaceID: ns.ID, }, resp.Auth); err != nil { @@ -1184,6 +1186,11 @@ func (c *Core) RegisterAuth(ctx context.Context, tokenTTL time.Duration, path st Type: auth.TokenType, } + if te.TTL == 0 && (len(te.Policies) != 1 || te.Policies[0] != "root") { + c.logger.Error("refusing to create a non-root zero TTL token") + return ErrInternalError + } + if err := c.tokenStore.create(ctx, &te); err != nil { c.logger.Error("failed to create token", "error", err) return ErrInternalError diff --git a/vault/token_store_test.go b/vault/token_store_test.go index 4b16409eb296..5c256c193ed6 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -55,11 +55,15 @@ func TestTokenStore_CubbyholeDeletion(t *testing.T) { Operation: logical.UpdateOperation, Path: "create", ClientToken: root, + Data: map[string]interface{}{ + "ttl": "600s", + }, } // Supplying token ID forces SHA1 hashing to be used if i%2 == 0 { tokenReq.Data = map[string]interface{}{ - "id": "testroot", + "id": "testroot", + "ttl": "600s", } } resp := testMakeTokenViaRequest(t, ts, tokenReq) @@ -111,6 +115,9 @@ func TestTokenStore_CubbyholeTidy(t *testing.T) { Operation: logical.UpdateOperation, Path: "create", ClientToken: root, + Data: map[string]interface{}{ + "ttl": "600s", + }, } resp := testMakeTokenViaRequest(t, ts, tokenReq) @@ -119,7 +126,8 @@ func TestTokenStore_CubbyholeTidy(t *testing.T) { // Supplying token ID forces SHA1 hashing to be used if i%3 == 0 { tokenReq.Data = map[string]interface{}{ - "id": "testroot", + "id": "testroot", + "ttl": "600s", } } @@ -545,10 +553,12 @@ func testMakeBatchTokenViaBackend(t testing.TB, ts *TokenStore, root, client, tt } func testMakeServiceTokenViaBackend(t testing.TB, ts *TokenStore, root, client, ttl string, policy []string) { + t.Helper() testMakeTokenViaBackend(t, ts, root, client, ttl, policy, false) } func testMakeTokenViaBackend(t testing.TB, ts *TokenStore, root, client, ttl string, policy []string, batch bool) { + t.Helper() req := logical.TestRequest(t, logical.UpdateOperation, "create") req.ClientToken = root if batch { @@ -566,6 +576,7 @@ func testMakeTokenViaBackend(t testing.TB, ts *TokenStore, root, client, ttl str } func testMakeTokenViaRequest(t testing.TB, ts *TokenStore, req *logical.Request) *logical.Response { + t.Helper() resp, err := ts.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatal(err) @@ -727,7 +738,7 @@ func TestTokenStore_HandleRequest_LookupAccessor(t *testing.T) { c, _, root := TestCoreUnsealed(t) ts := c.tokenStore - testMakeServiceTokenViaBackend(t, ts, root, "tokenid", "", []string{"foo"}) + testMakeServiceTokenViaBackend(t, ts, root, "tokenid", "60s", []string{"foo"}) out, err := ts.Lookup(namespace.RootContext(nil), "tokenid") if err != nil { t.Fatalf("err: %s", err) @@ -765,7 +776,7 @@ func TestTokenStore_HandleRequest_ListAccessors(t *testing.T) { testKeys := []string{"token1", "token2", "token3", "token4"} for _, key := range testKeys { - testMakeServiceTokenViaBackend(t, ts, root, key, "", []string{"foo"}) + testMakeServiceTokenViaBackend(t, ts, root, key, "60s", []string{"foo"}) } // Revoke root to make the number of accessors match @@ -2125,7 +2136,7 @@ func TestTokenStore_HandleRequest_Revoke(t *testing.T) { } root := rootToken.ID - testMakeServiceTokenViaBackend(t, ts, root, "child", "", []string{"root", "foo"}) + testMakeServiceTokenViaBackend(t, ts, root, "child", "60s", []string{"root", "foo"}) te, err := ts.Lookup(namespace.RootContext(nil), "child") if err != nil { @@ -2147,7 +2158,7 @@ func TestTokenStore_HandleRequest_Revoke(t *testing.T) { t.Fatalf("err: %v", err) } - testMakeServiceTokenViaBackend(t, ts, "child", "sub-child", "", []string{"foo"}) + testMakeServiceTokenViaBackend(t, ts, "child", "sub-child", "50s", []string{"foo"}) te, err = ts.Lookup(namespace.RootContext(nil), "sub-child") if err != nil { @@ -2201,8 +2212,8 @@ func TestTokenStore_HandleRequest_Revoke(t *testing.T) { } // Now test without registering the tokens through the expiration manager - testMakeServiceTokenViaBackend(t, ts, root, "child", "", []string{"root", "foo"}) - testMakeServiceTokenViaBackend(t, ts, "child", "sub-child", "", []string{"foo"}) + testMakeServiceTokenViaBackend(t, ts, root, "child", "60s", []string{"root", "foo"}) + testMakeServiceTokenViaBackend(t, ts, "child", "sub-child", "50s", []string{"foo"}) req = logical.TestRequest(t, logical.UpdateOperation, "revoke") req.Data = map[string]interface{}{ @@ -2239,8 +2250,8 @@ func TestTokenStore_HandleRequest_Revoke(t *testing.T) { func TestTokenStore_HandleRequest_RevokeOrphan(t *testing.T) { c, _, root := TestCoreUnsealed(t) ts := c.tokenStore - testMakeServiceTokenViaBackend(t, ts, root, "child", "", []string{"root", "foo"}) - testMakeServiceTokenViaBackend(t, ts, "child", "sub-child", "", []string{"foo"}) + testMakeServiceTokenViaBackend(t, ts, root, "child", "60s", []string{"root", "foo"}) + testMakeServiceTokenViaBackend(t, ts, "child", "sub-child", "50s", []string{"foo"}) req := logical.TestRequest(t, logical.UpdateOperation, "revoke-orphan") req.Data = map[string]interface{}{ @@ -2291,7 +2302,7 @@ func TestTokenStore_HandleRequest_RevokeOrphan(t *testing.T) { func TestTokenStore_HandleRequest_RevokeOrphan_NonRoot(t *testing.T) { c, _, root := TestCoreUnsealed(t) ts := c.tokenStore - testMakeServiceTokenViaBackend(t, ts, root, "child", "", []string{"foo"}) + testMakeServiceTokenViaBackend(t, ts, root, "child", "60s", []string{"foo"}) out, err := ts.Lookup(namespace.RootContext(nil), "child") if err != nil {