Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Time-based transit key autorotation #13691

Merged
merged 11 commits into from
Jan 20, 2022
105 changes: 100 additions & 5 deletions builtin/logical/transit/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ import (
"context"
"fmt"
"io"
"strconv"
"strings"
"sync"
"time"

"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/hashicorp/vault/sdk/logical"
)
Expand Down Expand Up @@ -59,9 +63,10 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error)
b.pathCacheConfig(),
},

Secrets: []*framework.Secret{},
Invalidate: b.invalidate,
BackendType: logical.TypeLogical,
Secrets: []*framework.Secret{},
Invalidate: b.invalidate,
BackendType: logical.TypeLogical,
PeriodicFunc: b.periodicFunc,
}

// determine cacheSize to use. Defaults to 0 which means unlimited
Expand Down Expand Up @@ -93,8 +98,10 @@ type backend struct {
*framework.Backend
lm *keysutil.LockManager
// Lock to make changes to any of the backend's cache configuration.
configMutex sync.RWMutex
cacheSizeChanged bool
configMutex sync.RWMutex
cacheSizeChanged bool
checkAutoRotateAfter time.Time
autoRotateOnce sync.Once
}

func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error) {
Expand Down Expand Up @@ -162,3 +169,91 @@ func (b *backend) invalidate(ctx context.Context, key string) {
b.cacheSizeChanged = true
}
}

// periodicFunc is a central collection of functions that run on an interval.
// Anything that should be called regularly can be placed within this method.
func (b *backend) periodicFunc(ctx context.Context, req *logical.Request) error {
// These operations ensure the auto-rotate only happens once simultaneously. It's an unlikely edge
// given the time scale, but a safeguard nonetheless.
var err error
didAutoRotate := false
autoRotateOnceFn := func() {
err = b.autoRotateKeys(ctx, req)
didAutoRotate = true
}
b.autoRotateOnce.Do(autoRotateOnceFn)
if didAutoRotate {
b.autoRotateOnce = sync.Once{}
}

return err
}

// autoRotateKeys retrieves all transit keys and rotates those which have an
// auto rotate interval defined which has passed. This operation only happens
// on primary nodes and performance secondary nodes which have a local mount.
func (b *backend) autoRotateKeys(ctx context.Context, req *logical.Request) error {
// Only check for autorotation once an hour to avoid unnecessarily iterating
// over all keys too frequently.
if time.Now().Before(b.checkAutoRotateAfter) {
return nil
}
b.checkAutoRotateAfter = time.Now().Add(1 * time.Hour)

// Early exit if not a primary or performance secondary with a local mount.
if b.System().ReplicationState().HasState(consts.ReplicationDRSecondary|consts.ReplicationPerformanceStandby) ||
(!b.System().LocalMount() && b.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary)) {
return nil
}

// Retrieve all keys and loop over them to check if they need to be rotated.
keys, err := req.Storage.List(ctx, "policy/")
if err != nil {
return err
}

// Collect errors in a multierror to ensure a single failure doesn't prevent
// all keys from being rotated.
var errs *multierror.Error

for _, key := range keys {
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: key,
}, b.GetRandomReader())
if err != nil {
errs = multierror.Append(errs, err)
continue
}

// If the policy is nil, move onto the next one.
if p == nil {
continue
}

schultz-is marked this conversation as resolved.
Show resolved Hide resolved
// If the policy's automatic rotation interval is 0, it should not
// automatically rotate.
if p.AutoRotateInterval == 0 {
continue
}

// Retrieve the latest version of the policy and determine if it is time to rotate.
latestKey := p.Keys[strconv.Itoa(p.LatestVersion)]
if time.Now().After(latestKey.CreationTime.Add(p.AutoRotateInterval)) {
if b.Logger().IsDebug() {
b.Logger().Debug("automatically rotating key", "key", key)
}
if !b.System().CachingDisabled() {
p.Lock(true)
}
err = p.Rotate(ctx, req.Storage, b.GetRandomReader())
p.Unlock()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this panic if caching is disabled? Unlocking a not locked mutex, i believe, panics

if err != nil {
errs = multierror.Append(errs, err)
continue
}
}
}

return errs.ErrorOrNil()
}
188 changes: 188 additions & 0 deletions builtin/logical/transit/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
uuid "github.com/hashicorp/go-uuid"
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/mitchellh/mapstructure"
Expand Down Expand Up @@ -1519,3 +1520,190 @@ func TestBadInput(t *testing.T) {
t.Fatal("expected error")
}
}

func TestTransit_AutoRotateKeys(t *testing.T) {
tests := map[string]struct {
isDRSecondary bool
isPerfSecondary bool
isStandby bool
isLocal bool
shouldRotate bool
}{
"primary, no local mount": {
shouldRotate: true,
},
"DR secondary, no local mount": {
isDRSecondary: true,
shouldRotate: false,
},
"perf standby, no local mount": {
isStandby: true,
shouldRotate: false,
},
"perf secondary, no local mount": {
isPerfSecondary: true,
shouldRotate: false,
},
"perf secondary, local mount": {
isPerfSecondary: true,
isLocal: true,
shouldRotate: true,
},
}

for name, test := range tests {
t.Run(
name,
func(t *testing.T) {
var repState consts.ReplicationState
if test.isDRSecondary {
repState.AddState(consts.ReplicationDRSecondary)
}
if test.isPerfSecondary {
repState.AddState(consts.ReplicationPerformanceSecondary)
}
if test.isStandby {
repState.AddState(consts.ReplicationPerformanceStandby)
}

sysView := logical.TestSystemView()
sysView.ReplicationStateVal = repState
sysView.LocalMountVal = test.isLocal

storage := &logical.InmemStorage{}

conf := &logical.BackendConfig{
StorageView: storage,
System: sysView,
}

b, _ := Backend(context.Background(), conf)
if b == nil {
t.Fatal("failed to create backend")
}

err := b.Backend.Setup(context.Background(), conf)
if err != nil {
t.Fatal(err)
}

// Write a key with the default auto rotate value (0/disabled)
req := &logical.Request{
Storage: storage,
Operation: logical.UpdateOperation,
Path: "keys/test1",
}
resp, err := b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}
if resp != nil {
t.Fatal("expected nil response")
}

// Write a key with an auto rotate value one day in the future
req = &logical.Request{
Storage: storage,
Operation: logical.UpdateOperation,
Path: "keys/test2",
Data: map[string]interface{}{
"auto_rotate_interval": 24 * time.Hour,
},
}
resp, err = b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}
if resp != nil {
t.Fatal("expected nil response")
}

// Run the rotation check and ensure none of the keys have rotated
b.checkAutoRotateAfter = time.Now()
if err = b.autoRotateKeys(context.Background(), &logical.Request{Storage: storage}); err != nil {
t.Fatal(err)
}
req = &logical.Request{
Storage: storage,
Operation: logical.ReadOperation,
Path: "keys/test1",
}
resp, err = b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if resp.Data["latest_version"] != 1 {
t.Fatalf("incorrect latest_version found, got: %d, want: %d", resp.Data["latest_version"], 1)
}

req.Path = "keys/test2"
resp, err = b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if resp.Data["latest_version"] != 1 {
t.Fatalf("incorrect latest_version found, got: %d, want: %d", resp.Data["latest_version"], 1)
}

// Update auto rotate interval on one key to be one nanosecond
p, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{
Storage: storage,
Name: "test2",
}, b.GetRandomReader())
if err != nil {
t.Fatal(err)
}
if p == nil {
t.Fatal("expected non-nil policy")
}
p.AutoRotateInterval = time.Nanosecond
err = p.Persist(context.Background(), storage)
if err != nil {
t.Fatal(err)
}

// Run the rotation check and validate the state of key rotations
b.checkAutoRotateAfter = time.Now()
if err = b.autoRotateKeys(context.Background(), &logical.Request{Storage: storage}); err != nil {
t.Fatal(err)
}
req = &logical.Request{
Storage: storage,
Operation: logical.ReadOperation,
Path: "keys/test1",
}
resp, err = b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if resp.Data["latest_version"] != 1 {
t.Fatalf("incorrect latest_version found, got: %d, want: %d", resp.Data["latest_version"], 1)
}
req.Path = "keys/test2"
resp, err = b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
expectedVersion := 1
if test.shouldRotate {
expectedVersion = 2
}
if resp.Data["latest_version"] != expectedVersion {
t.Fatalf("incorrect latest_version found, got: %d, want: %d", resp.Data["latest_version"], expectedVersion)
}
},
)
}
}
25 changes: 25 additions & 0 deletions builtin/logical/transit/path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package transit
import (
"context"
"fmt"
"time"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/keysutil"
Expand Down Expand Up @@ -47,6 +48,13 @@ the latest version of the key is allowed.`,
Type: framework.TypeBool,
Description: `Enables taking a backup of the named key in plaintext format. Once set, this cannot be disabled.`,
},

"auto_rotate_interval": {
Type: framework.TypeDurationSecond,
Description: `Amount of time the key should live before
being automatically rotated. A value of 0
disables automatic rotation for the key.`,
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
Expand Down Expand Up @@ -185,6 +193,23 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
}
}

autoRotateIntervalRaw, ok, err := d.GetOkErr("auto_rotate_interval")
if err != nil {
return nil, err
}
if ok {
autoRotateInterval := time.Second * time.Duration(autoRotateIntervalRaw.(int))
// Provided value must be 0 to disable or at least an hour
if autoRotateInterval != 0 && autoRotateInterval < time.Hour {
return logical.ErrorResponse("auto rotate interval must be 0 to disable or at least an hour"), nil
}

if autoRotateInterval != p.AutoRotateInterval {
p.AutoRotateInterval = autoRotateInterval
persistNeeded = true
}
}

if !persistNeeded {
return nil, nil
}
Expand Down
Loading