diff --git a/builtin/logical/transit/path_import.go b/builtin/logical/transit/path_import.go index ce04766d76ee..d1e34052f951 100644 --- a/builtin/logical/transit/path_import.go +++ b/builtin/logical/transit/path_import.go @@ -104,6 +104,28 @@ key.`, } } +func (b *backend) pathImportVersion() *framework.Path { + return &framework.Path{ + Pattern: "keys/" + framework.GenericNameRegex("name") + "/import_version", + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "The name of the key", + }, + "ciphertext": { + Type: framework.TypeString, + Description: `The base64-encoded ciphertext of the keys. The AES key should be encrypted using OAEP +with the wrapping key and then concatenated with the import key, wrapped by the AES key.`, + }, + }, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: b.pathImportVersionWrite, + }, + HelpSynopsis: pathImportVersionWriteSyn, + HelpDescription: pathImportVersionWriteDesc, + } +} + func (b *backend) pathImportWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) derived := d.Get("derived").(bool) @@ -185,6 +207,48 @@ func (b *backend) pathImportWrite(ctx context.Context, req *logical.Request, d * return nil, nil } +func (b *backend) pathImportVersionWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + ciphertextString := d.Get("ciphertext").(string) + + polReq := keysutil.PolicyRequest{ + Storage: req.Storage, + Name: name, + Upsert: false, + } + + p, _, err := b.GetPolicy(ctx, polReq, b.GetRandomReader()) + if err != nil { + return nil, err + } + if p == nil { + return nil, fmt.Errorf("no key found with name %s; to import a new key, use the import/ endpoint", name) + } + if !p.Imported { + return nil, errors.New("the import_version endpoint can only be used with an imported key") + } + if p.ConvergentEncryption { + return nil, errors.New("import_version cannot be used on keys with convergent encryption enabled") + } + + if !b.System().CachingDisabled() { + p.Lock(true) + } + defer p.Unlock() + + ciphertext, err := base64.RawURLEncoding.DecodeString(ciphertextString) + if err != nil { + return nil, err + } + importKey, err := b.decryptImportedKey(ctx, req.Storage, ciphertext) + err = p.Import(ctx, req.Storage, importKey, b.GetRandomReader()) + if err != nil { + return nil, err + } + + return nil, nil +} + func (b *backend) decryptImportedKey(ctx context.Context, storage logical.Storage, ciphertext []byte) ([]byte, error) { wrappedAESKey := ciphertext[:EncryptedKeyBytes] wrappedImportKey := ciphertext[EncryptedKeyBytes:] @@ -218,3 +282,6 @@ func (b *backend) decryptImportedKey(ctx context.Context, storage logical.Storag const pathImportWriteSyn = "Imports an externally-generated key into transit" const pathImportWriteDesc = "" + +const pathImportVersionWriteSyn = "" +const pathImportVersionWriteDesc = "" diff --git a/builtin/logical/transit/path_import_version..go b/builtin/logical/transit/path_import_version..go deleted file mode 100644 index 33f8e4693506..000000000000 --- a/builtin/logical/transit/path_import_version..go +++ /dev/null @@ -1,78 +0,0 @@ -package transit - -import ( - "context" - "encoding/base64" - "errors" - "fmt" - "github.com/hashicorp/vault/sdk/framework" - "github.com/hashicorp/vault/sdk/helper/keysutil" - "github.com/hashicorp/vault/sdk/logical" -) - -func (b *backend) pathImportVersion() *framework.Path { - return &framework.Path{ - Pattern: "keys/" + framework.GenericNameRegex("name") + "/import_version", - Fields: map[string]*framework.FieldSchema{ - "name": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "The name of the key", - }, - "ciphertext": { - Type: framework.TypeString, - Description: `The base64-encoded ciphertext of the keys. The AES key should be encrypted using OAEP -with the wrapping key and then concatenated with the import key, wrapped by the AES key.`, - }, - }, - Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: b.pathImportVersionWrite, - }, - HelpSynopsis: pathImportVersionWriteSyn, - HelpDescription: pathImportVersionWriteDesc, - } -} - -func (b *backend) pathImportVersionWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - name := d.Get("name").(string) - ciphertextString := d.Get("ciphertext").(string) - - polReq := keysutil.PolicyRequest{ - Storage: req.Storage, - Name: name, - Upsert: false, - } - - p, _, err := b.GetPolicy(ctx, polReq, b.GetRandomReader()) - if err != nil { - return nil, err - } - if p == nil { - return nil, fmt.Errorf("no key found with name %s; to import a new key, use the import/ endpoint", name) - } - if !p.Imported { - return nil, errors.New("the import_version endpoint can only be used with an imported key") - } - if p.ConvergentEncryption { - return nil, errors.New("import_version cannot be used on keys with convergent encryption enabled") - } - - if !b.System().CachingDisabled() { - p.Lock(true) - } - defer p.Unlock() - - ciphertext, err := base64.RawURLEncoding.DecodeString(ciphertextString) - if err != nil { - return nil, err - } - importKey, err := b.decryptImportedKey(ctx, req.Storage, ciphertext) - err = p.Import(ctx, req.Storage, importKey, b.GetRandomReader()) - if err != nil { - return nil, err - } - - return nil, nil -} - -const pathImportVersionWriteSyn = "" -const pathImportVersionWriteDesc = "" diff --git a/sdk/helper/keysutil/policy.go b/sdk/helper/keysutil/policy.go index 07645521879d..59029756d2e1 100644 --- a/sdk/helper/keysutil/policy.go +++ b/sdk/helper/keysutil/policy.go @@ -1390,21 +1390,20 @@ func (p *Policy) Import(ctx context.Context, storage logical.Storage, key []byte switch parsedPrivateKey.(type) { case *ecdsa.PrivateKey: + if p.Type != KeyType_ECDSA_P256 && p.Type != KeyType_ECDSA_P384 && p.Type != KeyType_ECDSA_P521 { + return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey) + } + ecdsaKey := parsedPrivateKey.(*ecdsa.PrivateKey) - var curve elliptic.Curve - if p.Type == KeyType_ECDSA_P256 { - curve = elliptic.P256() - } else if p.Type == KeyType_ECDSA_P384 { + curve := elliptic.P256() + if p.Type == KeyType_ECDSA_P384 { curve = elliptic.P384() } else if p.Type == KeyType_ECDSA_P521 { curve = elliptic.P521() } - if curve == nil { - return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey) - } if ecdsaKey.Curve != curve { - return fmt.Errorf("invalid curve: expected %s, got %s", p.Type, curve) + return fmt.Errorf("invalid curve: expected %s, got %s", curve.Params().Name, ecdsaKey.Curve.Params().Name) } entry.EC_D = ecdsaKey.D @@ -1433,21 +1432,19 @@ func (p *Policy) Import(ctx context.Context, storage logical.Storage, key []byte publicKey := privateKey.Public().(ed25519.PublicKey) entry.FormattedPublicKey = base64.StdEncoding.EncodeToString(publicKey) case *rsa.PrivateKey: - var keyBits int - if p.Type == KeyType_RSA2048 { - keyBits = 2048 - } else if p.Type == KeyType_RSA3072 { - keyBits = 3072 - } else if p.Type == KeyType_RSA4096 { - keyBits = 4096 + if p.Type != KeyType_RSA2048 && p.Type != KeyType_RSA3072 && p.Type != KeyType_RSA4096 { + return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey) } - rsaKey := parsedPrivateKey.(*rsa.PrivateKey) - if keyBits == 0 { - return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, rsaKey) + keyBytes := 256 + if p.Type == KeyType_RSA3072 { + keyBytes = 384 + } else if p.Type == KeyType_RSA4096 { + keyBytes = 512 } - if rsaKey.Size() != keyBits { - return fmt.Errorf("invalid key size: expected %s, got %d", p.Type, rsaKey.Size()) + rsaKey := parsedPrivateKey.(*rsa.PrivateKey) + if rsaKey.Size() != keyBytes { + return fmt.Errorf("invalid key size: expected %d bytes, got %d bytes", keyBytes, rsaKey.Size()) } entry.RSAKey = rsaKey @@ -1506,6 +1503,7 @@ func (p *Policy) Rotate(ctx context.Context, storage logical.Storage, randReader return err } + p.Imported = false return p.Persist(ctx, storage) } diff --git a/sdk/helper/keysutil/policy_test.go b/sdk/helper/keysutil/policy_test.go index e1ad6dde7c3d..d8212773641a 100644 --- a/sdk/helper/keysutil/policy_test.go +++ b/sdk/helper/keysutil/policy_test.go @@ -3,7 +3,12 @@ package keysutil import ( "bytes" "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" + "crypto/rsa" + "crypto/x509" + "golang.org/x/crypto/ed25519" "reflect" "strconv" "sync" @@ -615,6 +620,113 @@ func Test_BadArchive(t *testing.T) { } } +func Test_Import(t *testing.T) { + ctx := context.Background() + storage := &logical.InmemStorage{} + testKeys, err := generateTestKeys() + if err != nil { + t.Fatalf("error generating test keys: %s", err) + } + + tests := map[string]struct { + policy Policy + key []byte + shouldError bool + }{ + "import AES key": { + policy: Policy{ + Name: "test-aes-key", + Type: KeyType_AES256_GCM96, + }, + key: testKeys[KeyType_AES256_GCM96], + shouldError: false, + }, + "import RSA key": { + policy: Policy{ + Name: "test-rsa-key", + Type: KeyType_RSA2048, + }, + key: testKeys[KeyType_RSA2048], + shouldError: false, + }, + "import ECDSA key": { + policy: Policy{ + Name: "test-ecdsa-key", + Type: KeyType_ECDSA_P256, + }, + key: testKeys[KeyType_ECDSA_P256], + shouldError: false, + }, + "import ED25519 key": { + policy: Policy{ + Name: "test-ed25519-key", + Type: KeyType_ED25519, + }, + key: testKeys[KeyType_ED25519], + shouldError: false, + }, + "import incorrect key type": { + policy: Policy{ + Name: "test-ed25519-key", + Type: KeyType_ED25519, + }, + key: testKeys[KeyType_AES256_GCM96], + shouldError: true, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if err := test.policy.Import(ctx, storage, test.key, rand.Reader); (err != nil) != test.shouldError { + t.Fatalf("error importing key: %s", err) + } + }) + } +} + +func generateTestKeys() (map[KeyType][]byte, error) { + keyMap := make(map[KeyType][]byte) + + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + rsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(rsaKey) + if err != nil { + return nil, err + } + keyMap[KeyType_RSA2048] = rsaKeyBytes + + ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, err + } + ecdsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(ecdsaKey) + if err != nil { + return nil, err + } + keyMap[KeyType_ECDSA_P256] = ecdsaKeyBytes + + _, ed25519Key, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, err + } + ed25519KeyBytes, err := x509.MarshalPKCS8PrivateKey(ed25519Key) + if err != nil { + return nil, err + } + keyMap[KeyType_ED25519] = ed25519KeyBytes + + aesKey := make([]byte, 32) + _, err = rand.Read(aesKey) + if err != nil { + return nil, err + } + keyMap[KeyType_AES256_GCM96] = aesKey + + return keyMap, nil +} + func BenchmarkSymmetric(b *testing.B) { ctx := context.Background() lm, _ := NewLockManager(true, 0)