diff --git a/builtin/logical/pki/backend_test.go b/builtin/logical/pki/backend_test.go index 823dcf427dcb..2ac08590767a 100644 --- a/builtin/logical/pki/backend_test.go +++ b/builtin/logical/pki/backend_test.go @@ -4656,6 +4656,23 @@ func TestRootWithExistingKey(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "unable to find PKI key for reference: my-key1") + // Fail if the specified key name is default. + _, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{ + "common_name": "root myvault.com", + "issuer_name": "my-issuer1", + "key_name": "Default", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "reserved keyword 'default' can not be used as key name") + + // Fail if the specified issuer name is default. + _, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{ + "common_name": "root myvault.com", + "issuer_name": "DEFAULT", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "reserved keyword 'default' can not be used as issuer name") + // Create the first CA resp, err := client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{ "common_name": "root myvault.com", @@ -4669,11 +4686,20 @@ func TestRootWithExistingKey(t *testing.T) { require.NotEmpty(t, myIssuerId1) require.NotEmpty(t, myKeyId1) + // Fail if the specified issuer name is re-used. + _, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{ + "common_name": "root myvault.com", + "issuer_name": "my-issuer1", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "issuer name already used") + // Create the second CA resp, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{ "common_name": "root myvault.com", "key_type": "rsa", "issuer_name": "my-issuer2", + "key_name": "root-key2", }) require.NoError(t, err) require.NotNil(t, resp.Data["certificate"]) @@ -4682,6 +4708,15 @@ func TestRootWithExistingKey(t *testing.T) { require.NotEmpty(t, myIssuerId2) require.NotEmpty(t, myKeyId2) + // Fail if the specified key name is re-used. + _, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{ + "common_name": "root myvault.com", + "issuer_name": "my-issuer3", + "key_name": "root-key2", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "key name already used") + // Create a third CA re-using key from CA 1 resp, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/existing", map[string]interface{}{ "common_name": "root myvault.com", @@ -4764,6 +4799,7 @@ func TestIntermediateWithExistingKey(t *testing.T) { resp, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/intermediate/internal", map[string]interface{}{ "common_name": "root myvault.com", "key_type": "rsa", + "key_name": "interkey1", }) require.NoError(t, err) // csr2 := resp.Data["csr"] diff --git a/builtin/logical/pki/ca_util.go b/builtin/logical/pki/ca_util.go index b9d6616f2736..41987217c030 100644 --- a/builtin/logical/pki/ca_util.go +++ b/builtin/logical/pki/ca_util.go @@ -116,7 +116,7 @@ func getKeyTypeAndBitsForRole(ctx context.Context, b *backend, data *framework.F } func getExistingPublicKey(ctx context.Context, s logical.Storage, data *framework.FieldData) (crypto.PublicKey, error) { - keyRef, err := getExistingKeyRef(data) + keyRef, err := getKeyRefWithErr(data) if err != nil { return nil, err } diff --git a/builtin/logical/pki/managed_key_util.go b/builtin/logical/pki/managed_key_util.go index caff650958d7..e3569cf3433a 100644 --- a/builtin/logical/pki/managed_key_util.go +++ b/builtin/logical/pki/managed_key_util.go @@ -19,7 +19,7 @@ func generateCABundle(ctx context.Context, _ *backend, input *inputBundle, data return nil, errEntOnly } if existingKeyRequested(input) { - keyRef, err := getExistingKeyRef(input.apiData) + keyRef, err := getKeyRefWithErr(input.apiData) if err != nil { return nil, err } @@ -33,7 +33,7 @@ func generateCSRBundle(ctx context.Context, _ *backend, input *inputBundle, data return nil, errEntOnly } if existingKeyRequested(input) { - keyRef, err := getExistingKeyRef(input.apiData) + keyRef, err := getKeyRefWithErr(input.apiData) if err != nil { return nil, err } diff --git a/builtin/logical/pki/path_fetch_issuers.go b/builtin/logical/pki/path_fetch_issuers.go index a85eb81da457..28e5f46f3ca4 100644 --- a/builtin/logical/pki/path_fetch_issuers.go +++ b/builtin/logical/pki/path_fetch_issuers.go @@ -25,7 +25,7 @@ func pathListIssuers(b *backend) *framework.Path { } } -func (b *backend) pathListIssuersHandler(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { +func (b *backend) pathListIssuersHandler(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) { var responseKeys []string responseInfo := make(map[string]interface{}) @@ -90,7 +90,7 @@ func (b *backend) pathGetIssuer(ctx context.Context, req *logical.Request, data return b.pathGetRawIssuer(ctx, req, data) } - issuerName := data.Get("issuer_ref").(string) + issuerName := getIssuerRef(data) if len(issuerName) == 0 { return logical.ErrorResponse("missing issuer reference"), nil } @@ -119,14 +119,14 @@ func (b *backend) pathGetIssuer(ctx context.Context, req *logical.Request, data } func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - issuerName := data.Get("issuer_ref").(string) + issuerName := getIssuerRef(data) if len(issuerName) == 0 { return logical.ErrorResponse("missing issuer reference"), nil } - newName := data.Get("issuer_name").(string) - if len(newName) > 0 && !nameMatcher.MatchString(newName) { - return logical.ErrorResponse("new issuer name outside of valid character limits"), nil + newName, err := getIssuerName(ctx, req.Storage, data) + if err != nil { + return logical.ErrorResponse(err.Error()), nil } ref, err := resolveIssuerReference(ctx, req.Storage, issuerName) @@ -162,7 +162,7 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da } func (b *backend) pathGetRawIssuer(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - issuerName := data.Get("issuer_ref").(string) + issuerName := getIssuerRef(data) if len(issuerName) == 0 { return logical.ErrorResponse("missing issuer reference"), nil } @@ -209,7 +209,7 @@ func (b *backend) pathGetRawIssuer(ctx context.Context, req *logical.Request, da } func (b *backend) pathDeleteIssuer(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - issuerName := data.Get("issuer_ref").(string) + issuerName := getIssuerRef(data) if len(issuerName) == 0 { return logical.ErrorResponse("missing issuer reference"), nil } diff --git a/builtin/logical/pki/path_intermediate.go b/builtin/logical/pki/path_intermediate.go index 6828c20b0a00..88dbef9c5e64 100644 --- a/builtin/logical/pki/path_intermediate.go +++ b/builtin/logical/pki/path_intermediate.go @@ -52,6 +52,10 @@ func (b *backend) pathGenerateIntermediate(ctx context.Context, req *logical.Req return errorResp, nil } + keyName, err := getKeyName(ctx, req.Storage, data) + if err != nil { + return logical.ErrorResponse(err.Error()), nil + } var resp *logical.Response input := &inputBundle{ role: role, @@ -110,7 +114,7 @@ func (b *backend) pathGenerateIntermediate(ctx context.Context, req *logical.Req } } - myKey, _, err := importKey(ctx, req.Storage, csrb.PrivateKey) + myKey, _, err := importKey(ctx, req.Storage, csrb.PrivateKey, keyName) if err != nil { return nil, err } diff --git a/builtin/logical/pki/path_issue_sign.go b/builtin/logical/pki/path_issue_sign.go index 6d72278b4067..4aae8c278cdd 100644 --- a/builtin/logical/pki/path_issue_sign.go +++ b/builtin/logical/pki/path_issue_sign.go @@ -218,7 +218,7 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d return nil, logical.ErrReadOnly } - issuerName := data.Get("issuer_ref").(string) + issuerName := getIssuerRef(data) if len(issuerName) == 0 { return logical.ErrorResponse("missing issuer reference"), nil } diff --git a/builtin/logical/pki/path_manage_issuers.go b/builtin/logical/pki/path_manage_issuers.go index 2357352fbf0a..d4bb435a7580 100644 --- a/builtin/logical/pki/path_manage_issuers.go +++ b/builtin/logical/pki/path_manage_issuers.go @@ -138,7 +138,7 @@ func (b *backend) pathImportIssuers(ctx context.Context, req *logical.Request, d for _, keyPem := range keys { // Handle import of private key. - key, existing, err := importKey(ctx, req.Storage, keyPem) + key, existing, err := importKey(ctx, req.Storage, keyPem, "") if err != nil { return logical.ErrorResponse(err.Error()), nil } diff --git a/builtin/logical/pki/path_root.go b/builtin/logical/pki/path_root.go index 8aa25558a859..8b650b78043a 100644 --- a/builtin/logical/pki/path_root.go +++ b/builtin/logical/pki/path_root.go @@ -77,10 +77,13 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request, role.MaxPathLength = &maxPathLength } - issuerName := "" - issuerNameIface, ok := data.GetOk("id") - if ok { - issuerName = issuerNameIface.(string) + issuerName, err := getIssuerName(ctx, req.Storage, data) + if err != nil { + return logical.ErrorResponse(err.Error()), nil + } + keyName, err := getKeyName(ctx, req.Storage, data) + if err != nil { + return logical.ErrorResponse(err.Error()), nil } input := &inputBundle{ @@ -149,7 +152,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request, } // Store it as the CA bundle - myIssuer, myKey, err := writeCaBundle(ctx, req.Storage, cb, issuerName) + myIssuer, myKey, err := writeCaBundle(ctx, req.Storage, cb, issuerName, keyName) if err != nil { return nil, err } @@ -192,7 +195,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request, func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { var err error - issuerName := data.Get("issuer_ref").(string) + issuerName := getIssuerRef(data) if len(issuerName) == 0 { return logical.ErrorResponse("missing issuer reference"), nil } @@ -341,7 +344,7 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R func (b *backend) pathIssuerSignSelfIssued(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { var err error - issuerName := data.Get("issuer_ref").(string) + issuerName := getIssuerRef(data) if len(issuerName) == 0 { return logical.ErrorResponse("missing issuer reference"), nil } diff --git a/builtin/logical/pki/storage.go b/builtin/logical/pki/storage.go index 8799caf1112e..17451e68eab4 100644 --- a/builtin/logical/pki/storage.go +++ b/builtin/logical/pki/storage.go @@ -36,6 +36,11 @@ func (p issuerId) String() string { return string(p) } +const ( + IssuerRefNotFound = issuerId("not-found") + KeyRefNotFound = keyId("not-found") +) + type key struct { ID keyId `json:"id" structs:"id" mapstructure:"id"` Name string `json:"name" structs:"name" mapstructure:"name"` @@ -112,7 +117,7 @@ func deleteKey(ctx context.Context, s logical.Storage, id keyId) error { return s.Delete(ctx, keyPrefix+id.String()) } -func importKey(ctx context.Context, s logical.Storage, keyValue string) (*key, bool, error) { +func importKey(ctx context.Context, s logical.Storage, keyValue string, keyName string) (*key, bool, error) { // importKey imports the specified PEM-format key (from keyValue) into // the new PKI storage format. The first return field is a reference to // the new key; the second is whether or not the key already existed @@ -154,6 +159,7 @@ func importKey(ctx context.Context, s logical.Storage, keyValue string) (*key, b // Haven't found a key, so we've gotta create it and write it into storage. var result key result.ID = genKeyId() + result.Name = keyName result.PrivateKey = keyValue // Extracting the signer is necessary for two reasons: first, to get the @@ -270,7 +276,7 @@ func resolveKeyReference(ctx context.Context, s logical.Storage, reference strin } // Otherwise, we must not have found the key. - return keyId("not-found"), errutil.UserError{Err: fmt.Sprintf("unable to find PKI key for reference: %v", reference)} + return KeyRefNotFound, errutil.UserError{Err: fmt.Sprintf("unable to find PKI key for reference: %v", reference)} } func fetchIssuerById(ctx context.Context, s logical.Storage, issuerId issuerId) (*issuer, error) { @@ -488,7 +494,7 @@ func resolveIssuerReference(ctx context.Context, s logical.Storage, reference st } // Otherwise, we must not have found the issuer. - return issuerId("not-found"), errutil.UserError{Err: fmt.Sprintf("unable to find PKI issuer for reference: %v", reference)} + return IssuerRefNotFound, errutil.UserError{Err: fmt.Sprintf("unable to find PKI issuer for reference: %v", reference)} } // Builds a certutil.CertBundle from the specified issuer identifier, @@ -519,7 +525,7 @@ func fetchCertBundleByIssuerId(ctx context.Context, s logical.Storage, id issuer return &bundle, nil } -func writeCaBundle(ctx context.Context, s logical.Storage, caBundle *certutil.CertBundle, issuerName string) (*issuer, *key, error) { +func writeCaBundle(ctx context.Context, s logical.Storage, caBundle *certutil.CertBundle, issuerName string, keyName string) (*issuer, *key, error) { allKeyIds, err := listKeys(ctx, s) if err != nil { return nil, nil, err @@ -530,7 +536,7 @@ func writeCaBundle(ctx context.Context, s logical.Storage, caBundle *certutil.Ce return nil, nil, err } - myKey, _, err := importKey(ctx, s, caBundle.PrivateKey) + myKey, _, err := importKey(ctx, s, caBundle.PrivateKey, keyName) if err != nil { return nil, nil, err } diff --git a/builtin/logical/pki/storage_migrations.go b/builtin/logical/pki/storage_migrations.go index c1e05263a77e..23c1b2399700 100644 --- a/builtin/logical/pki/storage_migrations.go +++ b/builtin/logical/pki/storage_migrations.go @@ -50,7 +50,7 @@ func migrateStorage(ctx context.Context, req *logical.InitializationRequest, log logger.Warn("performing PKI migration to new keys/issuers layout") - anIssuer, aKey, err := writeCaBundle(ctx, s, legacyBundle, "") + anIssuer, aKey, err := writeCaBundle(ctx, s, legacyBundle, "", "") if err != nil { return err } diff --git a/builtin/logical/pki/storage_test.go b/builtin/logical/pki/storage_test.go index b96118a44827..b1b4e277b6ac 100644 --- a/builtin/logical/pki/storage_test.go +++ b/builtin/logical/pki/storage_test.go @@ -102,29 +102,32 @@ func Test_KeysIssuerImport(t *testing.T) { issuer1.ID = "" issuer1.KeyID = "" - key1_ref1, existing, err := importKey(ctx, s, key1.PrivateKey) + key1_ref1, existing, err := importKey(ctx, s, key1.PrivateKey, "key1") require.NoError(t, err) require.False(t, existing) require.Equal(t, key1.PrivateKey, key1_ref1.PrivateKey) - key1_ref2, existing, err := importKey(ctx, s, key1.PrivateKey) + key1_ref2, existing, err := importKey(ctx, s, key1.PrivateKey, "ignore-me") require.NoError(t, err) require.True(t, existing) require.Equal(t, key1.PrivateKey, key1_ref1.PrivateKey) require.Equal(t, key1_ref1.ID, key1_ref2.ID) + require.Equal(t, key1_ref1.Name, key1_ref2.Name) - issuer1_ref1, existing, err := importIssuer(ctx, s, issuer1.Certificate, "") + issuer1_ref1, existing, err := importIssuer(ctx, s, issuer1.Certificate, "issuer1") require.NoError(t, err) require.False(t, existing) require.Equal(t, issuer1.Certificate, issuer1_ref1.Certificate) require.Equal(t, key1_ref1.ID, issuer1_ref1.KeyID) + require.Equal(t, "issuer1", issuer1_ref1.Name) - issuer1_ref2, existing, err := importIssuer(ctx, s, issuer1.Certificate, "") + issuer1_ref2, existing, err := importIssuer(ctx, s, issuer1.Certificate, "ignore-me") require.NoError(t, err) require.True(t, existing) require.Equal(t, issuer1.Certificate, issuer1_ref1.Certificate) require.Equal(t, issuer1_ref1.ID, issuer1_ref2.ID) require.Equal(t, key1_ref1.ID, issuer1_ref2.KeyID) + require.Equal(t, issuer1_ref1.Name, issuer1_ref2.Name) err = writeIssuer(ctx, s, &issuer2) require.NoError(t, err) @@ -132,18 +135,20 @@ func Test_KeysIssuerImport(t *testing.T) { err = writeKey(ctx, s, key2) require.NoError(t, err) - issuer2_ref, existing, err := importIssuer(ctx, s, issuer2.Certificate, "") + issuer2_ref, existing, err := importIssuer(ctx, s, issuer2.Certificate, "ignore-me") require.NoError(t, err) require.True(t, existing) require.Equal(t, issuer2.Certificate, issuer2_ref.Certificate) - require.Equal(t, issuer2_ref.ID, issuer2.ID) - require.Equal(t, issuer2_ref.KeyID, issuer2.KeyID) + require.Equal(t, issuer2.ID, issuer2_ref.ID) + require.Equal(t, "", issuer2_ref.Name) + require.Equal(t, issuer2.KeyID, issuer2_ref.KeyID) - key2_ref, existing, err := importKey(ctx, s, key2.PrivateKey) + key2_ref, existing, err := importKey(ctx, s, key2.PrivateKey, "ignore-me") require.NoError(t, err) require.True(t, existing) require.Equal(t, key2.PrivateKey, key2_ref.PrivateKey) - require.Equal(t, key2_ref.ID, key2.ID) + require.Equal(t, key2.ID, key2_ref.ID) + require.Equal(t, "", key2_ref.Name) } func genIssuerAndKey(t *testing.T, b *backend) (issuer, key) { diff --git a/builtin/logical/pki/util.go b/builtin/logical/pki/util.go index 03413ee6e1b1..5e92134739ac 100644 --- a/builtin/logical/pki/util.go +++ b/builtin/logical/pki/util.go @@ -1,9 +1,12 @@ package pki import ( + "context" "fmt" "strings" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/errutil" @@ -79,17 +82,14 @@ func getManagedKeyId(data *framework.FieldData) (managedKeyId, error) { return keyId, nil } -func getExistingKeyRef(data *framework.FieldData) (string, error) { - keyRef, ok := data.GetOk("key_ref") - if !ok { - return "", errutil.UserError{Err: fmt.Sprintf("missing argument key_ref for existing type")} - } - trimmedKeyRef := strings.TrimSpace(keyRef.(string)) - if len(trimmedKeyRef) == 0 { +func getKeyRefWithErr(data *framework.FieldData) (string, error) { + keyRef := getKeyRef(data) + + if len(keyRef) == 0 { return "", errutil.UserError{Err: fmt.Sprintf("missing argument key_ref for existing type")} } - return trimmedKeyRef, nil + return keyRef, nil } func getManagedKeyNameOrUUID(data *framework.FieldData) (name string, UUID string, err error) { @@ -122,3 +122,75 @@ func getManagedKeyNameOrUUID(data *framework.FieldData) (name string, UUID strin return keyName, keyUUID, nil } + +func getIssuerName(ctx context.Context, s logical.Storage, data *framework.FieldData) (string, error) { + issuerName := "" + issuerNameIface, ok := data.GetOk("issuer_name") + if ok { + issuerName = strings.TrimSpace(issuerNameIface.(string)) + + if strings.ToLower(issuerName) == "default" { + return "", errutil.UserError{Err: "reserved keyword 'default' can not be used as issuer name"} + } + + if !nameMatcher.MatchString(issuerName) { + return "", errutil.UserError{Err: "issuer name contained invalid characters"} + } + issuer_id, err := resolveIssuerReference(ctx, s, issuerName) + if err == nil { + return "", errutil.UserError{Err: "issuer name already used."} + } + + if err != nil && issuer_id != IssuerRefNotFound { + return "", errutil.InternalError{Err: err.Error()} + } + } + return issuerName, nil +} + +func getKeyName(ctx context.Context, s logical.Storage, data *framework.FieldData) (string, error) { + keyName := "" + keyNameIface, ok := data.GetOk("key_name") + if ok { + keyName = strings.TrimSpace(keyNameIface.(string)) + + if strings.ToLower(keyName) == "default" { + return "", errutil.UserError{Err: "reserved keyword 'default' can not be used as key name"} + } + + if !nameMatcher.MatchString(keyName) { + return "", errutil.UserError{Err: "key name contained invalid characters"} + } + key_id, err := resolveKeyReference(ctx, s, keyName) + if err == nil { + return "", errutil.UserError{Err: "key name already used."} + } + + if err != nil && key_id != KeyRefNotFound { + return "", errutil.InternalError{Err: err.Error()} + } + } + return keyName, nil +} + +func getIssuerRef(data *framework.FieldData) string { + return extractRef(data, "issuer_ref") +} + +func getKeyRef(data *framework.FieldData) string { + return extractRef(data, "key_ref") +} + +func extractRef(data *framework.FieldData, paramName string) string { + value := "" + issuerNameIface, ok := data.GetOk(paramName) + if ok { + value = strings.TrimSpace(issuerNameIface.(string)) + if strings.ToLower(value) == "default" { + return "default" + } + return value + } + + return value +}