diff --git a/vault/expiration.go b/vault/expiration.go index 81ed2c99ca9d..492396c9b24b 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -599,9 +599,18 @@ func (m *ExpirationManager) revokePrefixCommon(prefix string, force bool) error defer m.restoreRequestLock.Unlock() } - // Ensure there is a trailing slash + // Ensure there is a trailing slash; or, if there is no slash, see if there + // is a matching specific ID if !strings.HasSuffix(prefix, "/") { - prefix = prefix + "/" + le, err := m.loadEntry(prefix) + if err == nil && le != nil { + if err := m.revokeCommon(prefix, force, false); err != nil { + return errwrap.Wrapf(fmt.Sprintf("failed to revoke %q: {{err}}", prefix), err) + } + return nil + } else { + prefix = prefix + "/" + } } // Accumulate existing leases diff --git a/vault/expiration_test.go b/vault/expiration_test.go index 77b7febff2b7..285d446dead8 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -1546,6 +1546,78 @@ func TestExpiration_RevokeForce(t *testing.T) { } } +func TestExpiration_RevokeForceSingle(t *testing.T) { + core, _, _, root := TestCoreWithTokenStore(t) + + core.logicalBackends["badrenew"] = badRenewFactory + me := &MountEntry{ + Table: mountTableType, + Path: "badrenew/", + Type: "badrenew", + Accessor: "badrenewaccessor", + } + + err := core.mount(context.Background(), me) + if err != nil { + t.Fatal(err) + } + + req := &logical.Request{ + Operation: logical.ReadOperation, + Path: "badrenew/creds", + ClientToken: root, + } + + resp, err := core.HandleRequest(req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("response was nil") + } + if resp.Secret == nil { + t.Fatalf("response secret was nil, response was %#v", *resp) + } + leaseID := resp.Secret.LeaseID + + req.Operation = logical.UpdateOperation + req.Path = "sys/leases/lookup" + req.Data = map[string]interface{}{"lease_id": leaseID} + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + if resp.Data["id"].(string) != leaseID { + t.Fatalf("expected id %q, got %q", leaseID, resp.Data["id"].(string)) + } + + req.Path = "sys/revoke-prefix/" + leaseID + + resp, err = core.HandleRequest(req) + if err == nil { + t.Fatal("expected error") + } + + req.Path = "sys/revoke-force/" + leaseID + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatalf("got error: %s", err) + } + + req.Path = "sys/leases/lookup" + req.Data = map[string]interface{}{"lease_id": leaseID} + resp, err = core.HandleRequest(req) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "invalid request") { + t.Fatalf("bad error: %v", err) + } +} + func badRenewFactory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { be := &framework.Backend{ Paths: []*framework.Path{