Skip to content

Commit

Permalink
refactor import endpoints and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rculpepper committed May 2, 2022
1 parent 5492c0a commit cfd87a2
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 98 deletions.
67 changes: 67 additions & 0 deletions builtin/logical/transit/path_import.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,28 @@ key.`,
}
}

func (b *backend) pathImportVersion() *framework.Path {
return &framework.Path{
Pattern: "keys/" + framework.GenericNameRegex("name") + "/import_version",
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "The name of the key",
},
"ciphertext": {
Type: framework.TypeString,
Description: `The base64-encoded ciphertext of the keys. The AES key should be encrypted using OAEP
with the wrapping key and then concatenated with the import key, wrapped by the AES key.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathImportVersionWrite,
},
HelpSynopsis: pathImportVersionWriteSyn,
HelpDescription: pathImportVersionWriteDesc,
}
}

func (b *backend) pathImportWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
derived := d.Get("derived").(bool)
Expand Down Expand Up @@ -185,6 +207,48 @@ func (b *backend) pathImportWrite(ctx context.Context, req *logical.Request, d *
return nil, nil
}

func (b *backend) pathImportVersionWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
ciphertextString := d.Get("ciphertext").(string)

polReq := keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
Upsert: false,
}

p, _, err := b.GetPolicy(ctx, polReq, b.GetRandomReader())
if err != nil {
return nil, err
}
if p == nil {
return nil, fmt.Errorf("no key found with name %s; to import a new key, use the import/ endpoint", name)
}
if !p.Imported {
return nil, errors.New("the import_version endpoint can only be used with an imported key")
}
if p.ConvergentEncryption {
return nil, errors.New("import_version cannot be used on keys with convergent encryption enabled")
}

if !b.System().CachingDisabled() {
p.Lock(true)
}
defer p.Unlock()

ciphertext, err := base64.RawURLEncoding.DecodeString(ciphertextString)
if err != nil {
return nil, err
}
importKey, err := b.decryptImportedKey(ctx, req.Storage, ciphertext)
err = p.Import(ctx, req.Storage, importKey, b.GetRandomReader())
if err != nil {
return nil, err
}

return nil, nil
}

func (b *backend) decryptImportedKey(ctx context.Context, storage logical.Storage, ciphertext []byte) ([]byte, error) {
wrappedAESKey := ciphertext[:EncryptedKeyBytes]
wrappedImportKey := ciphertext[EncryptedKeyBytes:]
Expand Down Expand Up @@ -218,3 +282,6 @@ func (b *backend) decryptImportedKey(ctx context.Context, storage logical.Storag

const pathImportWriteSyn = "Imports an externally-generated key into transit"
const pathImportWriteDesc = ""

const pathImportVersionWriteSyn = ""
const pathImportVersionWriteDesc = ""
78 changes: 0 additions & 78 deletions builtin/logical/transit/path_import_version..go

This file was deleted.

38 changes: 18 additions & 20 deletions sdk/helper/keysutil/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -1390,21 +1390,20 @@ func (p *Policy) Import(ctx context.Context, storage logical.Storage, key []byte

switch parsedPrivateKey.(type) {
case *ecdsa.PrivateKey:
if p.Type != KeyType_ECDSA_P256 && p.Type != KeyType_ECDSA_P384 && p.Type != KeyType_ECDSA_P521 {
return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey)
}

ecdsaKey := parsedPrivateKey.(*ecdsa.PrivateKey)
var curve elliptic.Curve
if p.Type == KeyType_ECDSA_P256 {
curve = elliptic.P256()
} else if p.Type == KeyType_ECDSA_P384 {
curve := elliptic.P256()
if p.Type == KeyType_ECDSA_P384 {
curve = elliptic.P384()
} else if p.Type == KeyType_ECDSA_P521 {
curve = elliptic.P521()
}

if curve == nil {
return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey)
}
if ecdsaKey.Curve != curve {
return fmt.Errorf("invalid curve: expected %s, got %s", p.Type, curve)
return fmt.Errorf("invalid curve: expected %s, got %s", curve.Params().Name, ecdsaKey.Curve.Params().Name)
}

entry.EC_D = ecdsaKey.D
Expand Down Expand Up @@ -1433,21 +1432,19 @@ func (p *Policy) Import(ctx context.Context, storage logical.Storage, key []byte
publicKey := privateKey.Public().(ed25519.PublicKey)
entry.FormattedPublicKey = base64.StdEncoding.EncodeToString(publicKey)
case *rsa.PrivateKey:
var keyBits int
if p.Type == KeyType_RSA2048 {
keyBits = 2048
} else if p.Type == KeyType_RSA3072 {
keyBits = 3072
} else if p.Type == KeyType_RSA4096 {
keyBits = 4096
if p.Type != KeyType_RSA2048 && p.Type != KeyType_RSA3072 && p.Type != KeyType_RSA4096 {
return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey)
}
rsaKey := parsedPrivateKey.(*rsa.PrivateKey)

if keyBits == 0 {
return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, rsaKey)
keyBytes := 256
if p.Type == KeyType_RSA3072 {
keyBytes = 384
} else if p.Type == KeyType_RSA4096 {
keyBytes = 512
}
if rsaKey.Size() != keyBits {
return fmt.Errorf("invalid key size: expected %s, got %d", p.Type, rsaKey.Size())
rsaKey := parsedPrivateKey.(*rsa.PrivateKey)
if rsaKey.Size() != keyBytes {
return fmt.Errorf("invalid key size: expected %d bytes, got %d bytes", keyBytes, rsaKey.Size())
}

entry.RSAKey = rsaKey
Expand Down Expand Up @@ -1506,6 +1503,7 @@ func (p *Policy) Rotate(ctx context.Context, storage logical.Storage, randReader
return err
}

p.Imported = false
return p.Persist(ctx, storage)
}

Expand Down
112 changes: 112 additions & 0 deletions sdk/helper/keysutil/policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@ package keysutil
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"golang.org/x/crypto/ed25519"
"reflect"
"strconv"
"sync"
Expand Down Expand Up @@ -615,6 +620,113 @@ func Test_BadArchive(t *testing.T) {
}
}

func Test_Import(t *testing.T) {
ctx := context.Background()
storage := &logical.InmemStorage{}
testKeys, err := generateTestKeys()
if err != nil {
t.Fatalf("error generating test keys: %s", err)
}

tests := map[string]struct {
policy Policy
key []byte
shouldError bool
}{
"import AES key": {
policy: Policy{
Name: "test-aes-key",
Type: KeyType_AES256_GCM96,
},
key: testKeys[KeyType_AES256_GCM96],
shouldError: false,
},
"import RSA key": {
policy: Policy{
Name: "test-rsa-key",
Type: KeyType_RSA2048,
},
key: testKeys[KeyType_RSA2048],
shouldError: false,
},
"import ECDSA key": {
policy: Policy{
Name: "test-ecdsa-key",
Type: KeyType_ECDSA_P256,
},
key: testKeys[KeyType_ECDSA_P256],
shouldError: false,
},
"import ED25519 key": {
policy: Policy{
Name: "test-ed25519-key",
Type: KeyType_ED25519,
},
key: testKeys[KeyType_ED25519],
shouldError: false,
},
"import incorrect key type": {
policy: Policy{
Name: "test-ed25519-key",
Type: KeyType_ED25519,
},
key: testKeys[KeyType_AES256_GCM96],
shouldError: true,
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
if err := test.policy.Import(ctx, storage, test.key, rand.Reader); (err != nil) != test.shouldError {
t.Fatalf("error importing key: %s", err)
}
})
}
}

func generateTestKeys() (map[KeyType][]byte, error) {
keyMap := make(map[KeyType][]byte)

rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
rsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(rsaKey)
if err != nil {
return nil, err
}
keyMap[KeyType_RSA2048] = rsaKeyBytes

ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}
ecdsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(ecdsaKey)
if err != nil {
return nil, err
}
keyMap[KeyType_ECDSA_P256] = ecdsaKeyBytes

_, ed25519Key, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
ed25519KeyBytes, err := x509.MarshalPKCS8PrivateKey(ed25519Key)
if err != nil {
return nil, err
}
keyMap[KeyType_ED25519] = ed25519KeyBytes

aesKey := make([]byte, 32)
_, err = rand.Read(aesKey)
if err != nil {
return nil, err
}
keyMap[KeyType_AES256_GCM96] = aesKey

return keyMap, nil
}

func BenchmarkSymmetric(b *testing.B) {
ctx := context.Background()
lm, _ := NewLockManager(true, 0)
Expand Down

0 comments on commit cfd87a2

Please sign in to comment.