diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index 315200fdabcb..71f94744db3d 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -38,6 +38,191 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) { return b, config.StorageView } +func TestTransit_RSA(t *testing.T) { + testTransit_RSA(t, "rsa-2048") + testTransit_RSA(t, "rsa-4096") +} + +func testTransit_RSA(t *testing.T, keyType string) { + var resp *logical.Response + var err error + b, storage := createBackendWithStorage(t) + + keyReq := &logical.Request{ + Path: "keys/rsa", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "type": keyType, + }, + Storage: storage, + } + + resp, err = b.HandleRequest(keyReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + + plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" // "the quick brown fox" + + encryptReq := &logical.Request{ + Path: "encrypt/rsa", + Operation: logical.UpdateOperation, + Storage: storage, + Data: map[string]interface{}{ + "plaintext": plaintext, + }, + } + + resp, err = b.HandleRequest(encryptReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + + ciphertext1 := resp.Data["ciphertext"].(string) + + decryptReq := &logical.Request{ + Path: "decrypt/rsa", + Operation: logical.UpdateOperation, + Storage: storage, + Data: map[string]interface{}{ + "ciphertext": ciphertext1, + }, + } + + resp, err = b.HandleRequest(decryptReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + + decryptedPlaintext := resp.Data["plaintext"] + + if plaintext != decryptedPlaintext { + t.Fatalf("bad: plaintext; expected: %q\nactual: %q", plaintext, decryptedPlaintext) + } + + // Rotate the key + rotateReq := &logical.Request{ + Path: "keys/rsa/rotate", + Operation: logical.UpdateOperation, + Storage: storage, + } + resp, err = b.HandleRequest(rotateReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + + // Encrypt again + resp, err = b.HandleRequest(encryptReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + ciphertext2 := resp.Data["ciphertext"].(string) + + if ciphertext1 == ciphertext2 { + t.Fatalf("expected different ciphertexts") + } + + // See if the older ciphertext can still be decrypted + resp, err = b.HandleRequest(decryptReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + if resp.Data["plaintext"].(string) != plaintext { + t.Fatal("failed to decrypt old ciphertext after rotating the key") + } + + // Decrypt the new ciphertext + decryptReq.Data = map[string]interface{}{ + "ciphertext": ciphertext2, + } + resp, err = b.HandleRequest(decryptReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + if resp.Data["plaintext"].(string) != plaintext { + t.Fatal("failed to decrypt ciphertext after rotating the key") + } + + signReq := &logical.Request{ + Path: "sign/rsa", + Operation: logical.UpdateOperation, + Storage: storage, + Data: map[string]interface{}{ + "input": plaintext, + }, + } + resp, err = b.HandleRequest(signReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + signature := resp.Data["signature"].(string) + + verifyReq := &logical.Request{ + Path: "verify/rsa", + Operation: logical.UpdateOperation, + Storage: storage, + Data: map[string]interface{}{ + "input": plaintext, + "signature": signature, + }, + } + + resp, err = b.HandleRequest(verifyReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + if !resp.Data["valid"].(bool) { + t.Fatalf("failed to verify the RSA signature") + } + + signReq.Data = map[string]interface{}{ + "input": plaintext, + "algorithm": "invalid", + } + resp, err = b.HandleRequest(signReq) + if err != nil { + t.Fatal(err) + } + if resp == nil || !resp.IsError() { + t.Fatal("expected an error response") + } + + signReq.Data = map[string]interface{}{ + "input": plaintext, + "algorithm": "sha2-512", + } + resp, err = b.HandleRequest(signReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + signature = resp.Data["signature"].(string) + + verifyReq.Data = map[string]interface{}{ + "input": plaintext, + "signature": signature, + } + resp, err = b.HandleRequest(verifyReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + if resp.Data["valid"].(bool) { + t.Fatalf("expected validation to fail") + } + + verifyReq.Data = map[string]interface{}{ + "input": plaintext, + "signature": signature, + "algorithm": "sha2-512", + } + resp, err = b.HandleRequest(verifyReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + if !resp.Data["valid"].(bool) { + t.Fatalf("failed to verify the RSA signature") + } +} + func TestBackend_basic(t *testing.T) { decryptData := make(map[string]interface{}) logicaltest.Test(t, logicaltest.TestCase{ diff --git a/builtin/logical/transit/path_encrypt_test.go b/builtin/logical/transit/path_encrypt_test.go index 6ab20db27130..d5866004535e 100644 --- a/builtin/logical/transit/path_encrypt_test.go +++ b/builtin/logical/transit/path_encrypt_test.go @@ -26,7 +26,7 @@ func TestTransit_BatchEncryptionCase1(t *testing.T) { t.Fatalf("err:%v resp:%#v", err, resp) } - plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" + plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" // "the quick brown fox" encData := map[string]interface{}{ "plaintext": plaintext, diff --git a/builtin/logical/transit/path_export.go b/builtin/logical/transit/path_export.go index a18db91b0f59..a218c22d8231 100644 --- a/builtin/logical/transit/path_export.go +++ b/builtin/logical/transit/path_export.go @@ -3,6 +3,7 @@ package transit import ( "crypto/ecdsa" "crypto/elliptic" + "crypto/rsa" "crypto/x509" "encoding/base64" "encoding/pem" @@ -152,6 +153,9 @@ func getExportKey(policy *keysutil.Policy, key *keysutil.KeyEntry, exportType st switch policy.Type { case keysutil.KeyType_AES256_GCM96: return strings.TrimSpace(base64.StdEncoding.EncodeToString(key.Key)), nil + + case keysutil.KeyType_RSA2048, keysutil.KeyType_RSA4096: + return encodeRSAPrivateKey(key.RSAKey), nil } case exportTypeSigningKey: @@ -165,12 +169,27 @@ func getExportKey(policy *keysutil.Policy, key *keysutil.KeyEntry, exportType st case keysutil.KeyType_ED25519: return strings.TrimSpace(base64.StdEncoding.EncodeToString(key.Key)), nil + + case keysutil.KeyType_RSA2048, keysutil.KeyType_RSA4096: + return encodeRSAPrivateKey(key.RSAKey), nil } } return "", fmt.Errorf("unknown key type %v", policy.Type) } +func encodeRSAPrivateKey(key *rsa.PrivateKey) string { + // When encoding PKCS1, the PEM header should be `RSA PRIVATE KEY`. When Go + // has PKCS8 encoding support, we may want to change this. + derBytes := x509.MarshalPKCS1PrivateKey(key) + pemBlock := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: derBytes, + } + pemBytes := pem.EncodeToMemory(pemBlock) + return string(pemBytes) +} + func keyEntryToECPrivateKey(k *keysutil.KeyEntry, curve elliptic.Curve) (string, error) { if k == nil { return "", errors.New("nil KeyEntry provided") diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index ad9a9188c254..42ce0b9b581a 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -2,7 +2,9 @@ package transit import ( "crypto/elliptic" + "crypto/x509" "encoding/base64" + "encoding/pem" "fmt" "strconv" "time" @@ -40,9 +42,11 @@ func (b *backend) pathKeys() *framework.Path { "type": &framework.FieldSchema{ Type: framework.TypeString, Default: "aes256-gcm96", - Description: `The type of key to create. Currently, -"aes256-gcm96" (symmetric) and "ecdsa-p256" (asymmetric), and -'ed25519' (asymmetric) are supported. Defaults to "aes256-gcm96".`, + Description: ` +The type of key to create. Currently, "aes256-gcm96" (symmetric), "ecdsa-p256" +(asymmetric), 'ed25519' (asymmetric), 'rsa-2048' (asymmetric), 'rsa-4096' +(asymmetric) are supported. Defaults to "aes256-gcm96". +`, }, "derived": &framework.FieldSchema{ @@ -131,6 +135,10 @@ func (b *backend) pathPolicyWrite( polReq.KeyType = keysutil.KeyType_ECDSA_P256 case "ed25519": polReq.KeyType = keysutil.KeyType_ED25519 + case "rsa-2048": + polReq.KeyType = keysutil.KeyType_RSA2048 + case "rsa-4096": + polReq.KeyType = keysutil.KeyType_RSA4096 default: return logical.ErrorResponse(fmt.Sprintf("unknown key type %v", keyType)), logical.ErrInvalidRequest } @@ -225,7 +233,7 @@ func (b *backend) pathPolicyRead( } resp.Data["keys"] = retKeys - case keysutil.KeyType_ECDSA_P256, keysutil.KeyType_ED25519: + case keysutil.KeyType_ECDSA_P256, keysutil.KeyType_ED25519, keysutil.KeyType_RSA2048, keysutil.KeyType_RSA4096: retKeys := map[string]map[string]interface{}{} for k, v := range p.Keys { key := asymKey{ @@ -253,6 +261,27 @@ func (b *backend) pathPolicyRead( } } key.Name = "ed25519" + case keysutil.KeyType_RSA2048, keysutil.KeyType_RSA4096: + key.Name = "rsa-2048" + if p.Type == keysutil.KeyType_RSA4096 { + key.Name = "rsa-4096" + } + + // Encode the RSA public key in PEM format to return over the + // API + derBytes, err := x509.MarshalPKIXPublicKey(v.RSAKey.Public()) + if err != nil { + return nil, fmt.Errorf("error marshaling RSA public key: %v", err) + } + pemBlock := &pem.Block{ + Type: "PUBLIC KEY", + Bytes: derBytes, + } + pemBytes := pem.EncodeToMemory(pemBlock) + if pemBytes == nil || len(pemBytes) == 0 { + return nil, fmt.Errorf("failed to PEM-encode RSA public key") + } + key.PublicKey = string(pemBytes) } retKeys[strconv.Itoa(k)] = structs.New(key).Map() diff --git a/builtin/logical/transit/path_sign_verify.go b/builtin/logical/transit/path_sign_verify.go index 074f7ff22236..6bdc96d3c6f7 100644 --- a/builtin/logical/transit/path_sign_verify.go +++ b/builtin/logical/transit/path_sign_verify.go @@ -37,7 +37,6 @@ derivation is enabled; currently only available with ed25519 keys.`, Default: "sha2-256", Description: `Hash algorithm to use (POST body parameter). Valid values are: -* none * sha2-224 * sha2-256 * sha2-384 @@ -58,6 +57,11 @@ including ed25519.`, Must be 0 (for latest) or a value greater than or equal to the min_encryption_version configured on the key.`, }, + + "prehashed": &framework.FieldSchema{ + Type: framework.TypeBool, + Description: `Set to 'true' when the input is already hashed. If the key type is 'rsa-2048' or 'rsa-4096', then the algorithm used to hash the input should be indicated by the 'algorithm' parameter.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -109,7 +113,6 @@ derivation is enabled; currently only available with ed25519 keys.`, Default: "sha2-256", Description: `Hash algorithm to use (POST body parameter). Valid values are: -* none * sha2-224 * sha2-256 * sha2-384 @@ -117,6 +120,11 @@ derivation is enabled; currently only available with ed25519 keys.`, Defaults to "sha2-256". Not valid for all key types.`, }, + + "prehashed": &framework.FieldSchema{ + Type: framework.TypeBool, + Description: `Set to 'true' when the input is already hashed. If the key type is 'rsa-2048' or 'rsa-4096', then the algorithm used to hash the input should be indicated by the 'algorithm' parameter.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -137,6 +145,7 @@ func (b *backend) pathSignWrite( if algorithm == "" { algorithm = d.Get("algorithm").(string) } + prehashed := d.Get("prehashed").(bool) input, err := base64.StdEncoding.DecodeString(inputB64) if err != nil { @@ -168,7 +177,7 @@ func (b *backend) pathSignWrite( } } - if p.Type.HashSignatureInput() && algorithm != "none" { + if p.Type.HashSignatureInput() && !prehashed { var hf hash.Hash switch algorithm { case "sha2-224": @@ -186,7 +195,7 @@ func (b *backend) pathSignWrite( input = hf.Sum(nil) } - sig, err := p.Sign(ver, context, input) + sig, err := p.Sign(ver, context, input, algorithm) if err != nil { return nil, err } @@ -230,6 +239,7 @@ func (b *backend) pathVerifyWrite( if algorithm == "" { algorithm = d.Get("algorithm").(string) } + prehashed := d.Get("prehashed").(bool) input, err := base64.StdEncoding.DecodeString(inputB64) if err != nil { @@ -261,7 +271,7 @@ func (b *backend) pathVerifyWrite( } } - if p.Type.HashSignatureInput() && algorithm != "none" { + if p.Type.HashSignatureInput() && !prehashed { var hf hash.Hash switch algorithm { case "sha2-224": @@ -279,7 +289,7 @@ func (b *backend) pathVerifyWrite( input = hf.Sum(nil) } - valid, err := p.VerifySignature(context, input, sig) + valid, err := p.VerifySignature(context, input, sig, algorithm) if err != nil { switch err.(type) { case errutil.UserError: diff --git a/builtin/logical/transit/path_sign_verify_test.go b/builtin/logical/transit/path_sign_verify_test.go index 1ab994f1da3c..f23f0a8edb8c 100644 --- a/builtin/logical/transit/path_sign_verify_test.go +++ b/builtin/logical/transit/path_sign_verify_test.go @@ -164,9 +164,10 @@ func TestTransit_SignVerify_P256(t *testing.T) { sig = signRequest(req, false, "") verifyRequest(req, false, "", sig) - req.Data["algorithm"] = "none" + req.Data["prehashed"] = true sig = signRequest(req, false, "") verifyRequest(req, false, "", sig) + delete(req.Data, "prehashed") // Test 512 and save sig for later to ensure we can't validate once min // decryption version is set diff --git a/helper/keysutil/lock_manager.go b/helper/keysutil/lock_manager.go index 75881997340e..fc28ea0eee1b 100644 --- a/helper/keysutil/lock_manager.go +++ b/helper/keysutil/lock_manager.go @@ -256,7 +256,13 @@ func (lm *LockManager) getPolicyCommon(req PolicyRequest, lockType bool) (*Polic case KeyType_ED25519: if req.Convergent { lm.UnlockPolicy(lock, lockType) - return nil, nil, false, fmt.Errorf("convergent encryption not not supported for keys of type %v", req.KeyType) + return nil, nil, false, fmt.Errorf("convergent encryption not supported for keys of type %v", req.KeyType) + } + + case KeyType_RSA2048, KeyType_RSA4096: + if req.Derived || req.Convergent { + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType) } default: diff --git a/helper/keysutil/policy.go b/helper/keysutil/policy.go index 85591f8e44bd..7ce53e266c77 100644 --- a/helper/keysutil/policy.go +++ b/helper/keysutil/policy.go @@ -9,6 +9,7 @@ import ( "crypto/elliptic" "crypto/hmac" "crypto/rand" + "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/asn1" @@ -44,6 +45,8 @@ const ( KeyType_AES256_GCM96 = iota KeyType_ECDSA_P256 KeyType_ED25519 + KeyType_RSA2048 + KeyType_RSA4096 ) const ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)" @@ -61,7 +64,7 @@ type KeyType int func (kt KeyType) EncryptionSupported() bool { switch kt { - case KeyType_AES256_GCM96: + case KeyType_AES256_GCM96, KeyType_RSA2048, KeyType_RSA4096: return true } return false @@ -69,7 +72,7 @@ func (kt KeyType) EncryptionSupported() bool { func (kt KeyType) DecryptionSupported() bool { switch kt { - case KeyType_AES256_GCM96: + case KeyType_AES256_GCM96, KeyType_RSA2048, KeyType_RSA4096: return true } return false @@ -77,7 +80,7 @@ func (kt KeyType) DecryptionSupported() bool { func (kt KeyType) SigningSupported() bool { switch kt { - case KeyType_ECDSA_P256, KeyType_ED25519: + case KeyType_ECDSA_P256, KeyType_ED25519, KeyType_RSA2048, KeyType_RSA4096: return true } return false @@ -85,7 +88,7 @@ func (kt KeyType) SigningSupported() bool { func (kt KeyType) HashSignatureInput() bool { switch kt { - case KeyType_ECDSA_P256: + case KeyType_ECDSA_P256, KeyType_RSA2048, KeyType_RSA4096: return true } return false @@ -107,6 +110,10 @@ func (kt KeyType) String() string { return "ecdsa-p256" case KeyType_ED25519: return "ed25519" + case KeyType_RSA2048: + return "rsa-2048" + case KeyType_RSA4096: + return "rsa-4096" } return "[unknown]" @@ -127,6 +134,8 @@ type KeyEntry struct { EC_Y *big.Int `json:"ec_y"` EC_D *big.Int `json:"ec_d"` + RSAKey *rsa.PrivateKey `json:"rsa_key"` + // The public key in an appropriate format for the type of key FormattedPublicKey string `json:"public_key"` @@ -519,13 +528,6 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)} } - // Guard against a potentially invalid key type - switch p.Type { - case KeyType_AES256_GCM96: - default: - return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} - } - // Decode the plaintext value plaintext, err := base64.StdEncoding.DecodeString(value) if err != nil { @@ -543,62 +545,69 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, return "", errutil.UserError{Err: "requested version for encryption is less than the minimum encryption key version"} } - // Derive the key that should be used - key, err := p.DeriveKey(context, ver) - if err != nil { - return "", err - } + var ciphertext []byte - // Guard against a potentially invalid key type switch p.Type { case KeyType_AES256_GCM96: - default: - return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} - } - - // Setup the cipher - aesCipher, err := aes.NewCipher(key) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } + // Derive the key that should be used + key, err := p.DeriveKey(context, ver) + if err != nil { + return "", err + } - // Setup the GCM AEAD - gcm, err := cipher.NewGCM(aesCipher) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } + // Setup the cipher + aesCipher, err := aes.NewCipher(key) + if err != nil { + return "", errutil.InternalError{Err: err.Error()} + } - if p.ConvergentEncryption { - switch p.ConvergentVersion { - case 1: - if len(nonce) != gcm.NonceSize() { - return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())} - } - default: - nonceHmac := hmac.New(sha256.New, context) - nonceHmac.Write(plaintext) - nonceSum := nonceHmac.Sum(nil) - nonce = nonceSum[:gcm.NonceSize()] - } - } else { - // Compute random nonce - nonce, err = uuid.GenerateRandomBytes(gcm.NonceSize()) + // Setup the GCM AEAD + gcm, err := cipher.NewGCM(aesCipher) if err != nil { return "", errutil.InternalError{Err: err.Error()} } - } - // Encrypt and tag with GCM - out := gcm.Seal(nil, nonce, plaintext, nil) + if p.ConvergentEncryption { + switch p.ConvergentVersion { + case 1: + if len(nonce) != gcm.NonceSize() { + return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())} + } + default: + nonceHmac := hmac.New(sha256.New, context) + nonceHmac.Write(plaintext) + nonceSum := nonceHmac.Sum(nil) + nonce = nonceSum[:gcm.NonceSize()] + } + } else { + // Compute random nonce + nonce, err = uuid.GenerateRandomBytes(gcm.NonceSize()) + if err != nil { + return "", errutil.InternalError{Err: err.Error()} + } + } + + // Encrypt and tag with GCM + ciphertext = gcm.Seal(nil, nonce, plaintext, nil) + + // Place the encrypted data after the nonce + if !p.ConvergentEncryption || p.ConvergentVersion > 1 { + ciphertext = append(nonce, ciphertext...) + } + + case KeyType_RSA2048, KeyType_RSA4096: + key := p.Keys[ver].RSAKey + ciphertext, err = rsa.EncryptOAEP(sha256.New(), rand.Reader, &key.PublicKey, plaintext, nil) + if err != nil { + return "", errutil.InternalError{Err: fmt.Sprintf("failed to RSA encrypt the plaintext: %v", err)} + } - // Place the encrypted data after the nonce - full := out - if !p.ConvergentEncryption || p.ConvergentVersion > 1 { - full = append(nonce, out...) + default: + return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} } // Convert to base64 - encoded := base64.StdEncoding.EncodeToString(full) + encoded := base64.StdEncoding.EncodeToString(ciphertext) // Prepend some information encoded = "vault:v" + strconv.Itoa(ver) + ":" + encoded @@ -644,54 +653,61 @@ func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) { return "", errutil.UserError{Err: ErrTooOld} } - // Derive the key that should be used - key, err := p.DeriveKey(context, ver) + // Decode the base64 + decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1]) if err != nil { - return "", err + return "", errutil.UserError{Err: "invalid ciphertext: could not decode base64"} } - // Guard against a potentially invalid key type + var plain []byte + switch p.Type { case KeyType_AES256_GCM96: - default: - return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} - } + key, err := p.DeriveKey(context, ver) + if err != nil { + return "", err + } - // Decode the base64 - decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1]) - if err != nil { - return "", errutil.UserError{Err: "invalid ciphertext: could not decode base64"} - } + // Setup the cipher + aesCipher, err := aes.NewCipher(key) + if err != nil { + return "", errutil.InternalError{Err: err.Error()} + } - // Setup the cipher - aesCipher, err := aes.NewCipher(key) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } + // Setup the GCM AEAD + gcm, err := cipher.NewGCM(aesCipher) + if err != nil { + return "", errutil.InternalError{Err: err.Error()} + } - // Setup the GCM AEAD - gcm, err := cipher.NewGCM(aesCipher) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } + if len(decoded) < gcm.NonceSize() { + return "", errutil.UserError{Err: "invalid ciphertext length"} + } - if len(decoded) < gcm.NonceSize() { - return "", errutil.UserError{Err: "invalid ciphertext length"} - } + // Extract the nonce and ciphertext + var ciphertext []byte + if p.ConvergentEncryption && p.ConvergentVersion < 2 { + ciphertext = decoded + } else { + nonce = decoded[:gcm.NonceSize()] + ciphertext = decoded[gcm.NonceSize():] + } - // Extract the nonce and ciphertext - var ciphertext []byte - if p.ConvergentEncryption && p.ConvergentVersion < 2 { - ciphertext = decoded - } else { - nonce = decoded[:gcm.NonceSize()] - ciphertext = decoded[gcm.NonceSize():] - } + // Verify and Decrypt + plain, err = gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return "", errutil.UserError{Err: "invalid ciphertext: unable to decrypt"} + } - // Verify and Decrypt - plain, err := gcm.Open(nil, nonce, ciphertext, nil) - if err != nil { - return "", errutil.UserError{Err: "invalid ciphertext: unable to decrypt"} + case KeyType_RSA2048, KeyType_RSA4096: + key := p.Keys[ver].RSAKey + plain, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, key, decoded, nil) + if err != nil { + return "", errutil.InternalError{Err: fmt.Sprintf("failed to RSA decrypt the ciphertext: %v", err)} + } + + default: + return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} } return base64.StdEncoding.EncodeToString(plain), nil @@ -712,7 +728,7 @@ func (p *Policy) HMACKey(version int) ([]byte, error) { return p.Keys[version].HMACKey, nil } -func (p *Policy) Sign(ver int, context, input []byte) (*SigningResult, error) { +func (p *Policy) Sign(ver int, context, input []byte, algorithm string) (*SigningResult, error) { if !p.Type.SigningSupported() { return nil, fmt.Errorf("message signing not supported for key type %v", p.Type) } @@ -777,6 +793,28 @@ func (p *Policy) Sign(ver int, context, input []byte) (*SigningResult, error) { return nil, err } + case KeyType_RSA2048, KeyType_RSA4096: + key := p.Keys[ver].RSAKey + + var algo crypto.Hash + switch algorithm { + case "sha2-224": + algo = crypto.SHA224 + case "sha2-256": + algo = crypto.SHA256 + case "sha2-384": + algo = crypto.SHA384 + case "sha2-512": + algo = crypto.SHA512 + default: + return nil, errutil.InternalError{Err: fmt.Sprintf("unsupported algorithm %s", algorithm)} + } + + sig, err = rsa.SignPSS(rand.Reader, key, algo, input, nil) + if err != nil { + return nil, err + } + default: return nil, fmt.Errorf("unsupported key type %v", p.Type) } @@ -792,7 +830,7 @@ func (p *Policy) Sign(ver int, context, input []byte) (*SigningResult, error) { return res, nil } -func (p *Policy) VerifySignature(context, input []byte, sig string) (bool, error) { +func (p *Policy) VerifySignature(context, input []byte, sig, algorithm string) (bool, error) { if !p.Type.SigningSupported() { return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)} } @@ -861,6 +899,27 @@ func (p *Policy) VerifySignature(context, input []byte, sig string) (bool, error return ed25519.Verify(key.Public().(ed25519.PublicKey), input, sigBytes), nil + case KeyType_RSA2048, KeyType_RSA4096: + key := p.Keys[ver].RSAKey + + var algo crypto.Hash + switch algorithm { + case "sha2-224": + algo = crypto.SHA224 + case "sha2-256": + algo = crypto.SHA256 + case "sha2-384": + algo = crypto.SHA384 + case "sha2-512": + algo = crypto.SHA512 + default: + return false, errutil.InternalError{Err: fmt.Sprintf("unsupported algorithm %s", algorithm)} + } + + err = rsa.VerifyPSS(&key.PublicKey, algo, input, sigBytes, nil) + + return err == nil, nil + default: return false, errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} } @@ -927,6 +986,17 @@ func (p *Policy) Rotate(storage logical.Storage) error { } entry.Key = pri entry.FormattedPublicKey = base64.StdEncoding.EncodeToString(pub) + + case KeyType_RSA2048, KeyType_RSA4096: + bitSize := 2048 + if p.Type == KeyType_RSA4096 { + bitSize = 4096 + } + + entry.RSAKey, err = rsa.GenerateKey(rand.Reader, bitSize) + if err != nil { + return err + } } p.Keys[p.LatestVersion] = entry diff --git a/website/source/api/secret/transit/index.html.md b/website/source/api/secret/transit/index.html.md index fa4de364a119..eaf32817a51d 100644 --- a/website/source/api/secret/transit/index.html.md +++ b/website/source/api/secret/transit/index.html.md @@ -52,6 +52,8 @@ values set here cannot be changed after key creation. (symmetric, supports derivation) - `ecdsa-p256` – ECDSA using the P-256 elliptic curve (asymmetric) - `ed25519` – ED25519 (asymmetric, supports derivation) + - `rsa-2048` - RSA with bit size of 2048 (asymmetric) + - `rsa-4096` - RSA with bit size of 4096 (asymmetric) ### Sample Payload