Skip to content

Commit

Permalink
Avoid map iteration induced flakiness in seal wrapper key searching (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sgmiller authored and Monkeychip committed Feb 5, 2024
1 parent 4f042f1 commit 583908c
Showing 1 changed file with 21 additions and 27 deletions.
48 changes: 21 additions & 27 deletions vault/seal/seal.go
Original file line number Diff line number Diff line change
Expand Up @@ -673,8 +673,6 @@ func (a *access) tryEncrypt(ctx context.Context, sealWrapper *SealWrapper, plain
// Returns the plaintext, a flag indicating whether the ciphertext is up-to-date
// (according to IsUpToDate), and an error.
func (a *access) Decrypt(ctx context.Context, ciphertext *MultiWrapValue, options ...wrapping.Option) ([]byte, bool, error) {
blobInfoMap := slotsByKeyId(ciphertext)

isUpToDate, err := a.IsUpToDate(ctx, ciphertext, false)
if err != nil {
return nil, false, err
Expand Down Expand Up @@ -719,7 +717,7 @@ func (a *access) Decrypt(ctx context.Context, ciphertext *MultiWrapValue, option
}

decrypt := func(sealWrapper *SealWrapper) {
pt, oldKey, err := a.tryDecrypt(ctx, sealWrapper, blobInfoMap, options)
pt, oldKey, err := a.tryDecrypt(ctx, sealWrapper, ciphertext, options)
reportResult(sealWrapper.Name, pt, oldKey, err)
}

Expand All @@ -731,21 +729,20 @@ outer:
// and ensures we'll use it first. This should equal the highest priority wrapper in the nominal
// case, but may not if a seal is unhealthy. This ensures we try the highest priority healthy
// seal first if available, and warn if we don't think we have one in common.
for k := range blobInfoMap {
for _, sealWrapper := range wrappersByPriority {
keyId, err := sealWrapper.Wrapper.KeyId(ctx)
if err != nil {
resultWg.Add(1)
go reportResult(sealWrapper.Name, nil, false, err)
continue
}
if keyId == k {
found = true
first = sealWrapper
break outer
}
for _, sealWrapper := range wrappersByPriority {
keyId, err := sealWrapper.Wrapper.KeyId(ctx)
if err != nil {
resultWg.Add(1)
go reportResult(sealWrapper.Name, nil, false, err)
continue
}
if bi := ciphertext.BlobInfoForKeyId(keyId); bi != nil {
found = true
first = sealWrapper
break outer
}
}

if !found {
a.logger.Warn("while unwrapping, value has no key-id in common with currently healthy seals. Trying all healthy seals")
}
Expand Down Expand Up @@ -800,7 +797,7 @@ GATHER_RESULTS:

// tryDecrypt returns the plaintext and a flag indicating whether the decryption was done by the "unwrapSeal" (see
// sealWrapMigration.Decrypt).
func (a *access) tryDecrypt(ctx context.Context, sealWrapper *SealWrapper, ciphertextByKeyId map[string]*wrapping.BlobInfo, options []wrapping.Option) ([]byte, bool, error) {
func (a *access) tryDecrypt(ctx context.Context, sealWrapper *SealWrapper, value *MultiWrapValue, options []wrapping.Option) ([]byte, bool, error) {
now := time.Now()
var decryptErr error
mLabels := []metrics.Label{{Name: "seal_wrapper_name", Value: sealWrapper.Name}}
Expand All @@ -821,15 +818,15 @@ func (a *access) tryDecrypt(ctx context.Context, sealWrapper *SealWrapper, ciphe
var keyId string
if id, err := sealWrapper.Wrapper.KeyId(ctx); err == nil {
keyId = id
if ciphertext, ok := ciphertextByKeyId[keyId]; ok {
if ciphertext := value.BlobInfoForKeyId(keyId); ciphertext != nil {
pt, decryptErr = sealWrapper.Wrapper.Decrypt(ctx, ciphertext, options...)

sealWrapper.SetHealthy(decryptErr == nil || IsOldKeyError(decryptErr), now)
}
}
// If we don't get a result, try all the slots
if pt == nil && decryptErr == nil {
for _, ciphertext := range ciphertextByKeyId {
for _, ciphertext := range value.Slots {
pt, decryptErr = sealWrapper.Wrapper.Decrypt(ctx, ciphertext, options...)
if decryptErr == nil {
// Note that we only update wrapper health for failures on exact key ID match,
Expand Down Expand Up @@ -914,16 +911,13 @@ func (a *access) GetShamirKeyBytes(ctx context.Context) ([]byte, error) {
return shamirWrapper.KeyBytes(ctx)
}

func slotsByKeyId(value *MultiWrapValue) map[string]*wrapping.BlobInfo {
ret := make(map[string]*wrapping.BlobInfo)
for _, blobInfo := range value.Slots {
keyId := ""
if blobInfo.KeyInfo != nil {
keyId = blobInfo.KeyInfo.KeyId
func (v *MultiWrapValue) BlobInfoForKeyId(keyId string) *wrapping.BlobInfo {
for _, blobInfo := range v.Slots {
if blobInfo.KeyInfo != nil && blobInfo.KeyInfo.KeyId == keyId {
return blobInfo
}
ret[keyId] = blobInfo
}
return ret
return nil
}

type keyIdSet struct {
Expand Down

0 comments on commit 583908c

Please sign in to comment.