From a76ec20981579633a20f8e6eae9aa9dbd2a96010 Mon Sep 17 00:00:00 2001 From: Steve Clark Date: Fri, 8 Apr 2022 12:51:59 -0400 Subject: [PATCH] WIP: Support root issuer generation --- builtin/logical/pki/backend.go | 1 + builtin/logical/pki/backend_test.go | 95 +++++++++++++ builtin/logical/pki/ca_util.go | 155 +++++++++++++++------ builtin/logical/pki/config_util.go | 19 +++ builtin/logical/pki/fields.go | 12 ++ builtin/logical/pki/managed_key_util.go | 34 ++++- builtin/logical/pki/path_manage_issuers.go | 26 +++- builtin/logical/pki/path_root.go | 21 +-- builtin/logical/pki/storage.go | 53 ++++++- builtin/logical/pki/storage_migrations.go | 33 +---- builtin/logical/pki/storage_test.go | 6 +- builtin/logical/pki/util.go | 31 ++++- 12 files changed, 396 insertions(+), 90 deletions(-) diff --git a/builtin/logical/pki/backend.go b/builtin/logical/pki/backend.go index 6ab418f03bbf..98257e0c51eb 100644 --- a/builtin/logical/pki/backend.go +++ b/builtin/logical/pki/backend.go @@ -116,6 +116,7 @@ func Backend(conf *logical.BackendConfig) *backend { pathIssuerSignIntermediate(&b), pathIssuerSignSelfIssued(&b), pathIssuerSignVerbatim(&b), + pathIssuerGenerateRoot(&b), pathConfigIssuers(&b), // Fetch APIs have been lowered to favor the newer issuer API endpoints diff --git a/builtin/logical/pki/backend_test.go b/builtin/logical/pki/backend_test.go index 53abd6128bcb..8120936d4af8 100644 --- a/builtin/logical/pki/backend_test.go +++ b/builtin/logical/pki/backend_test.go @@ -4613,6 +4613,101 @@ func TestBackend_Roles_KeySizeRegression(t *testing.T) { t.Log(fmt.Sprintf("Key size regression expanded matrix test scenarios: %d", tested)) } +func TestRootWithExistingKey(t *testing.T) { + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "pki": Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + var err error + + mountPKIEndpoint(t, client, "pki-root") + + // Fail requests if type is existing, and we specify the key_type param + ctx := context.Background() + _, err = client.Logical().WriteWithContext(ctx, "pki-root/root/generate/existing", map[string]interface{}{ + "common_name": "root myvault.com", + "key_type": "rsa", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "key_type nor key_bits arguments can be set in this mode") + + // Fail requests if type is existing, and we specify the key_bits param + _, err = client.Logical().WriteWithContext(ctx, "pki-root/root/generate/existing", map[string]interface{}{ + "common_name": "root myvault.com", + "key_bits": "2048", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "key_type nor key_bits arguments can be set in this mode") + + // Fail if the specified key does not exist. + _, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/existing", map[string]interface{}{ + "common_name": "root myvault.com", + "id": "my-issuer1", + "key_id": "my-key1", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "unable to find PKI key for reference: my-key1") + + // Create the first CA + resp, err := client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{ + "common_name": "root myvault.com", + "key_type": "rsa", + "id": "my-issuer1", + }) + require.NoError(t, err) + require.NotNil(t, resp.Data["certificate"]) + myIssuerId1 := resp.Data["id"] + myKeyId1 := resp.Data["key_id"] + require.NotEmpty(t, myIssuerId1) + require.NotEmpty(t, myKeyId1) + + // Create the second CA + resp, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{ + "common_name": "root myvault.com", + "key_type": "rsa", + "id": "my-issuer2", + }) + require.NoError(t, err) + require.NotNil(t, resp.Data["certificate"]) + myIssuerId2 := resp.Data["id"] + myKeyId2 := resp.Data["key_id"] + require.NotEmpty(t, myIssuerId2) + require.NotEmpty(t, myKeyId2) + + // Create a third CA re-using key from CA 1 + resp, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/existing", map[string]interface{}{ + "common_name": "root myvault.com", + "id": "my-issuer3", + "key_id": myKeyId1, + }) + require.NoError(t, err) + require.NotNil(t, resp.Data["certificate"]) + myIssuerId3 := resp.Data["id"] + myKeyId3 := resp.Data["key_id"] + require.NotEmpty(t, myIssuerId3) + require.NotEmpty(t, myKeyId3) + + require.NotEqual(t, myIssuerId1, myIssuerId2) + require.NotEqual(t, myIssuerId1, myIssuerId3) + require.NotEqual(t, myKeyId1, myKeyId2) + require.Equal(t, myKeyId1, myKeyId3) + + resp, err = client.Logical().ListWithContext(ctx, "pki-root/issuers") + require.NoError(t, err) + require.Equal(t, 3, len(resp.Data["keys"].([]interface{}))) + require.Contains(t, resp.Data["keys"], myIssuerId1) + require.Contains(t, resp.Data["keys"], myIssuerId2) + require.Contains(t, resp.Data["keys"], myIssuerId3) +} + var ( initTest sync.Once rsaCAKey string diff --git a/builtin/logical/pki/ca_util.go b/builtin/logical/pki/ca_util.go index e7a9e8700488..b9d6616f2736 100644 --- a/builtin/logical/pki/ca_util.go +++ b/builtin/logical/pki/ca_util.go @@ -2,8 +2,10 @@ package pki import ( "context" + "crypto" "crypto/ecdsa" "crypto/rsa" + "errors" "fmt" "time" @@ -14,18 +16,17 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) -func (b *backend) getGenerationParams(ctx context.Context, - data *framework.FieldData, mountPoint string, -) (exported bool, format string, role *roleEntry, errorResp *logical.Response) { +func (b *backend) getGenerationParams(ctx context.Context, data *framework.FieldData, mountPoint string) (exported bool, format string, role *roleEntry, errorResp *logical.Response) { exportedStr := data.Get("exported").(string) switch exportedStr { case "exported": exported = true case "internal": + case "existing": case "kms": default: errorResp = logical.ErrorResponse( - `the "exported" path parameter must be "internal", "exported" or "kms"`) + `the "exported" path parameter must be "internal", "existing", exported" or "kms"`) return } @@ -36,46 +37,10 @@ func (b *backend) getGenerationParams(ctx context.Context, return } - keyType := data.Get("key_type").(string) - keyBits := data.Get("key_bits").(int) - if exportedStr == "kms" { - _, okKeyType := data.Raw["key_type"] - _, okKeyBits := data.Raw["key_bits"] - - if okKeyType || okKeyBits { - errorResp = logical.ErrorResponse( - `invalid parameter for the kms path parameter, key_type nor key_bits arguments can be set in this mode`) - return - } - - keyId, err := getManagedKeyId(data) - if err != nil { - errorResp = logical.ErrorResponse("unable to determine managed key id") - return - } - // Determine key type and key bits from the managed public key - err = withManagedPKIKey(ctx, b, keyId, mountPoint, func(ctx context.Context, key logical.ManagedSigningKey) error { - pubKey, err := key.GetPublicKey(ctx) - if err != nil { - return err - } - switch pubKey.(type) { - case *rsa.PublicKey: - keyType = "rsa" - keyBits = pubKey.(*rsa.PublicKey).Size() * 8 - case *ecdsa.PublicKey: - keyType = "ec" - case *ed25519.PublicKey: - keyType = "ed25519" - default: - return fmt.Errorf("unsupported public key: %#v", pubKey) - } - return nil - }) - if err != nil { - errorResp = logical.ErrorResponse("failed to lookup public key from managed key: %s", err.Error()) - return - } + keyType, keyBits, err := getKeyTypeAndBitsForRole(ctx, b, data, mountPoint) + if err != nil { + errorResp = logical.ErrorResponse(err.Error()) + return } role = &roleEntry{ @@ -101,10 +66,110 @@ func (b *backend) getGenerationParams(ctx context.Context, } *role.AllowWildcardCertificates = true - var err error if role.KeyBits, role.SignatureBits, err = certutil.ValidateDefaultOrValueKeyTypeSignatureLength(role.KeyType, role.KeyBits, role.SignatureBits); err != nil { errorResp = logical.ErrorResponse(err.Error()) } return } + +func getKeyTypeAndBitsForRole(ctx context.Context, b *backend, data *framework.FieldData, mountPoint string) (string, int, error) { + exportedStr := data.Get("exported").(string) + var keyType string + var keyBits int + + switch exportedStr { + case "internal": + fallthrough + case "exported": + keyType = data.Get("key_type").(string) + keyBits = data.Get("key_bits").(int) + return keyType, keyBits, nil + } + + // existing and kms types don't support providing the key_type and key_bits args. + _, okKeyType := data.Raw["key_type"] + _, okKeyBits := data.Raw["key_bits"] + + if okKeyType || okKeyBits { + return "", 0, errors.New("invalid parameter for the kms/existing path parameter, key_type nor key_bits arguments can be set in this mode") + } + + var pubKey crypto.PublicKey + if kmsRequestedFromFieldData(data) { + pubKeyManagedKey, err := getManagedKeyPublicKey(ctx, b, data, mountPoint) + if err != nil { + return "", 0, errors.New("failed to lookup public key from managed key: " + err.Error()) + } + pubKey = pubKeyManagedKey + } + + if existingKeyRequestedFromFieldData(data) { + existingPubKey, err := getExistingPublicKey(ctx, b.storage, data) + if err != nil { + return "", 0, errors.New("failed to lookup public key from existing key: " + err.Error()) + } + pubKey = existingPubKey + } + + return getKeyTypeAndBitsFromPublicKeyForRole(pubKey) +} + +func getExistingPublicKey(ctx context.Context, s logical.Storage, data *framework.FieldData) (crypto.PublicKey, error) { + keyRef, err := getExistingKeyRef(data) + if err != nil { + return nil, err + } + id, err := resolveKeyReference(ctx, s, keyRef) + if err != nil { + return nil, err + } + key, err := fetchKeyById(ctx, s, id) + if err != nil { + return nil, err + } + signer, err := key.GetSigner() + if err != nil { + return nil, err + } + return signer.Public(), nil +} + +func getKeyTypeAndBitsFromPublicKeyForRole(pubKey crypto.PublicKey) (string, int, error) { + var keyType string + var keyBits int + + switch pubKey.(type) { + case *rsa.PublicKey: + keyType = "rsa" + keyBits = certutil.GetPublicKeySize(pubKey) + case *ecdsa.PublicKey: + keyType = "ec" + case *ed25519.PublicKey: + keyType = "ed25519" + default: + return "", 0, fmt.Errorf("unsupported public key: %#v", pubKey) + } + return keyType, keyBits, nil +} + +func getManagedKeyPublicKey(ctx context.Context, b *backend, data *framework.FieldData, mountPoint string) (crypto.PublicKey, error) { + keyId, err := getManagedKeyId(data) + if err != nil { + return nil, errors.New("unable to determine managed key id") + } + // Determine key type and key bits from the managed public key + var pubKey crypto.PublicKey + err = withManagedPKIKey(ctx, b, keyId, mountPoint, func(ctx context.Context, key logical.ManagedSigningKey) error { + pubKey, err = key.GetPublicKey(ctx) + if err != nil { + return err + } + + return nil + }) + if err != nil { + return nil, errors.New("failed to lookup public key from managed key: " + err.Error()) + } + return pubKey, nil +} diff --git a/builtin/logical/pki/config_util.go b/builtin/logical/pki/config_util.go index 2ba36fe9d0fd..830590fbd8b0 100644 --- a/builtin/logical/pki/config_util.go +++ b/builtin/logical/pki/config_util.go @@ -2,10 +2,29 @@ package pki import ( "context" + "strings" "github.com/hashicorp/vault/sdk/logical" ) +func isKeyDefaultSet(ctx context.Context, s logical.Storage) (bool, error) { + config, err := getKeysConfig(ctx, s) + if err != nil { + return false, err + } + + return strings.TrimSpace(config.DefaultKeyId.String()) != "", nil +} + +func isIssuerDefaultSet(ctx context.Context, s logical.Storage) (bool, error) { + config, err := getIssuersConfig(ctx, s) + if err != nil { + return false, err + } + + return strings.TrimSpace(config.DefaultIssuerId.String()) != "", nil +} + func updateDefaultKeyId(ctx context.Context, s logical.Storage, id keyId) error { config, err := getKeysConfig(ctx, s) if err != nil { diff --git a/builtin/logical/pki/fields.go b/builtin/logical/pki/fields.go index 4593d6d9a7cc..fa94751451b7 100644 --- a/builtin/logical/pki/fields.go +++ b/builtin/logical/pki/fields.go @@ -314,6 +314,18 @@ SHA-2-512. Defaults to 0 to automatically detect based on key length Value: "rsa", }, } + + fields["key_id"] = &framework.FieldSchema{ + Type: framework.TypeString, + Description: `Reference to a existing key; either "default" +for the configured default key, an identifier or the name assigned +to the key. Note this is only used for the existing generation type.`, + } + + fields["id"] = &framework.FieldSchema{ + Type: framework.TypeString, + Description: `Assign a name to the generated issuer.`, + } return fields } diff --git a/builtin/logical/pki/managed_key_util.go b/builtin/logical/pki/managed_key_util.go index 4c16d0d2a410..acb11e624065 100644 --- a/builtin/logical/pki/managed_key_util.go +++ b/builtin/logical/pki/managed_key_util.go @@ -4,6 +4,7 @@ package pki import ( "context" + "encoding/pem" "errors" "io" @@ -13,10 +14,17 @@ import ( var errEntOnly = errors.New("managed keys are supported within enterprise edition only") -func generateCABundle(_ context.Context, _ *backend, input *inputBundle, data *certutil.CreationBundle, randomSource io.Reader) (*certutil.ParsedCertBundle, error) { +func generateCABundle(ctx context.Context, _ *backend, input *inputBundle, data *certutil.CreationBundle, randomSource io.Reader) (*certutil.ParsedCertBundle, error) { if kmsRequested(input) { return nil, errEntOnly } + if existingKeyRequested(input) { + keyRef, err := getExistingKeyRef(input.apiData) + if err != nil { + return nil, err + } + return certutil.CreateCertificateWithKeyGenerator(data, randomSource, existingGeneratePrivateKey(ctx, input.req.Storage, keyRef)) + } return certutil.CreateCertificateWithRandomSource(data, randomSource) } @@ -35,3 +43,27 @@ func parseCABundle(_ context.Context, _ *backend, _ *logical.Request, bundle *ce func withManagedPKIKey(_ context.Context, _ *backend, _ managedKeyId, _ string, _ logical.ManagedSigningKeyConsumer) error { return errEntOnly } + +func existingGeneratePrivateKey(ctx context.Context, s logical.Storage, keyRef string) certutil.KeyGenerator { + return func(keyType string, keyBits int, container certutil.ParsedPrivateKeyContainer, _ io.Reader) error { + keyId, err := resolveKeyReference(ctx, s, keyRef) + if err != nil { + return err + } + key, err := fetchKeyById(ctx, s, keyId) + if err != nil { + return err + } + signer, err := key.GetSigner() + if err != nil { + return err + } + privateKeyType := certutil.GetPrivateKeyTypeFromSigner(signer) + if privateKeyType == certutil.UnknownPrivateKey { + return errors.New("unknown private key type loaded from key id: " + keyId.String()) + } + blk, _ := pem.Decode([]byte(key.PrivateKey)) + container.SetParsedPrivateKey(signer, privateKeyType, blk.Bytes) + return nil + } +} diff --git a/builtin/logical/pki/path_manage_issuers.go b/builtin/logical/pki/path_manage_issuers.go index 907f8aa8da4b..b9ac8bc5740b 100644 --- a/builtin/logical/pki/path_manage_issuers.go +++ b/builtin/logical/pki/path_manage_issuers.go @@ -11,6 +11,30 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) +func pathIssuerGenerateRoot(b *backend) *framework.Path { + ret := &framework.Path{ + Pattern: "issuers/generate/root/" + framework.GenericNameRegex("exported"), + + Operations: map[logical.Operation]framework.OperationHandler{ + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.pathCAGenerateRoot, + // Read more about why these flags are set in backend.go + ForwardPerformanceStandby: true, + ForwardPerformanceSecondary: true, + }, + }, + + HelpSynopsis: pathGenerateRootHelpSyn, + HelpDescription: pathGenerateRootHelpDesc, + } + + ret.Fields = addCACommonFields(map[string]*framework.FieldSchema{}) + ret.Fields = addCAKeyGenerationFields(ret.Fields) + ret.Fields = addCAIssueFields(ret.Fields) + + return ret +} + func pathImportIssuer(b *backend) *framework.Path { return &framework.Path{ Pattern: "issuers/import/(cert|bundle)", @@ -88,7 +112,7 @@ func (b *backend) pathImportIssuers(ctx context.Context, req *logical.Request, d } for _, certPem := range issuers { - cert, existing, err := importIssuer(ctx, req.Storage, certPem) + cert, existing, err := importIssuer(ctx, req.Storage, certPem, "") if err != nil { return logical.ErrorResponse(err.Error()), nil } diff --git a/builtin/logical/pki/path_root.go b/builtin/logical/pki/path_root.go index 75d1dfda002f..5f84c1152358 100644 --- a/builtin/logical/pki/path_root.go +++ b/builtin/logical/pki/path_root.go @@ -97,6 +97,12 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request, role.MaxPathLength = &maxPathLength } + issuerName := "" + issuerNameIface, ok := data.GetOk("id") + if ok { + issuerName = issuerNameIface.(string) + } + input := &inputBundle{ req: req, apiData: data, @@ -163,14 +169,12 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request, } // Store it as the CA bundle - entry, err = logical.StorageEntryJSON("config/ca_bundle", cb) - if err != nil { - return nil, err - } - err = req.Storage.Put(ctx, entry) + myIssuer, myKey, err := writeCaBundle(ctx, req.Storage, cb, issuerName) if err != nil { return nil, err } + resp.Data["id"] = myIssuer.ID + resp.Data["key_id"] = myKey.ID // Also store it as just the certificate identified by serial number, so it // can be revoked @@ -184,9 +188,10 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request, // For ease of later use, also store just the certificate at a known // location - entry.Key = "ca" - entry.Value = parsedBundle.CertificateBytes - err = req.Storage.Put(ctx, entry) + err = req.Storage.Put(ctx, &logical.StorageEntry{ + Key: "ca", + Value: parsedBundle.CertificateBytes, + }) if err != nil { return nil, err } diff --git a/builtin/logical/pki/storage.go b/builtin/logical/pki/storage.go index 4b6d5b23930e..8799caf1112e 100644 --- a/builtin/logical/pki/storage.go +++ b/builtin/logical/pki/storage.go @@ -306,7 +306,7 @@ func deleteIssuer(ctx context.Context, s logical.Storage, id issuerId) error { return s.Delete(ctx, issuerPrefix+id.String()) } -func importIssuer(ctx context.Context, s logical.Storage, certValue string) (*issuer, bool, error) { +func importIssuer(ctx context.Context, s logical.Storage, certValue string, issuerName string) (*issuer, bool, error) { // importIssuers imports the specified PEM-format certificate (from // certValue) into the new PKI storage format. The first return field is a // reference to the new issuer; the second is whether or not the issuer @@ -349,6 +349,7 @@ func importIssuer(ctx context.Context, s logical.Storage, certValue string) (*is // storage. var result issuer result.ID = genIssuerId() + result.Name = issuerName result.Certificate = certValue result.CAChain = []string{certValue} @@ -518,6 +519,56 @@ func fetchCertBundleByIssuerId(ctx context.Context, s logical.Storage, id issuer return &bundle, nil } +func writeCaBundle(ctx context.Context, s logical.Storage, caBundle *certutil.CertBundle, issuerName string) (*issuer, *key, error) { + allKeyIds, err := listKeys(ctx, s) + if err != nil { + return nil, nil, err + } + + allIssuerIds, err := listIssuers(ctx, s) + if err != nil { + return nil, nil, err + } + + myKey, _, err := importKey(ctx, s, caBundle.PrivateKey) + if err != nil { + return nil, nil, err + } + + myIssuer, _, err := importIssuer(ctx, s, caBundle.Certificate, issuerName) + if err != nil { + return nil, nil, err + } + + for _, cert := range caBundle.CAChain { + if _, _, err = importIssuer(ctx, s, cert, ""); err != nil { + return nil, nil, err + } + } + + keyDefaultSet, err := isKeyDefaultSet(ctx, s) + if err != nil { + return nil, nil, err + } + if len(allKeyIds) == 0 || !keyDefaultSet { + if err = updateDefaultKeyId(ctx, s, myKey.ID); err != nil { + return nil, nil, err + } + } + + issuerDefaultSet, err := isIssuerDefaultSet(ctx, s) + if err != nil { + return nil, nil, err + } + if len(allIssuerIds) == 0 || !issuerDefaultSet { + if err = updateDefaultIssuerId(ctx, s, myIssuer.ID); err != nil { + return nil, nil, err + } + } + + return myIssuer, myKey, nil +} + func genIssuerId() issuerId { return issuerId(genUuid()) } diff --git a/builtin/logical/pki/storage_migrations.go b/builtin/logical/pki/storage_migrations.go index 73882f57461f..c1e05263a77e 100644 --- a/builtin/logical/pki/storage_migrations.go +++ b/builtin/logical/pki/storage_migrations.go @@ -50,10 +50,12 @@ func migrateStorage(ctx context.Context, req *logical.InitializationRequest, log logger.Warn("performing PKI migration to new keys/issuers layout") - err = migrateToIssuers(ctx, s, legacyBundle) + anIssuer, aKey, err := writeCaBundle(ctx, s, legacyBundle, "") if err != nil { return err } + logger.Info("Migration generated the following ids and set them as defaults", + "issuer id", anIssuer.ID, "key id", aKey.ID) err = setLegacyBundleMigrationLog(ctx, s, &legacyBundleMigration{ hash: hash, @@ -82,35 +84,6 @@ func computeHashOfLegacyBundle(bundle *certutil.CertBundle) (string, error) { return hex.EncodeToString(hasher.Sum(nil)), nil } -func migrateToIssuers(ctx context.Context, s logical.Storage, bundle *certutil.CertBundle) error { - defaultKey, _, err := importKey(ctx, s, bundle.PrivateKey) - if err != nil { - return err - } - - defaultIssuer, _, err := importIssuer(ctx, s, bundle.Certificate) - if err != nil { - return err - } - - for _, cert := range bundle.CAChain { - if _, _, err = importIssuer(ctx, s, cert); err != nil { - return err - } - } - - if err = updateDefaultKeyId(ctx, s, defaultKey.ID); err != nil { - return err - } - - if err = updateDefaultIssuerId(ctx, s, defaultIssuer.ID); err != nil { - return err - } - - // FIXME: Call function that will recompute the CAChain on issuers here. - return nil -} - type legacyBundleMigration struct { hash string created time.Time diff --git a/builtin/logical/pki/storage_test.go b/builtin/logical/pki/storage_test.go index f8ec0e8978a4..b96118a44827 100644 --- a/builtin/logical/pki/storage_test.go +++ b/builtin/logical/pki/storage_test.go @@ -113,13 +113,13 @@ func Test_KeysIssuerImport(t *testing.T) { require.Equal(t, key1.PrivateKey, key1_ref1.PrivateKey) require.Equal(t, key1_ref1.ID, key1_ref2.ID) - issuer1_ref1, existing, err := importIssuer(ctx, s, issuer1.Certificate) + issuer1_ref1, existing, err := importIssuer(ctx, s, issuer1.Certificate, "") require.NoError(t, err) require.False(t, existing) require.Equal(t, issuer1.Certificate, issuer1_ref1.Certificate) require.Equal(t, key1_ref1.ID, issuer1_ref1.KeyID) - issuer1_ref2, existing, err := importIssuer(ctx, s, issuer1.Certificate) + issuer1_ref2, existing, err := importIssuer(ctx, s, issuer1.Certificate, "") require.NoError(t, err) require.True(t, existing) require.Equal(t, issuer1.Certificate, issuer1_ref1.Certificate) @@ -132,7 +132,7 @@ func Test_KeysIssuerImport(t *testing.T) { err = writeKey(ctx, s, key2) require.NoError(t, err) - issuer2_ref, existing, err := importIssuer(ctx, s, issuer2.Certificate) + issuer2_ref, existing, err := importIssuer(ctx, s, issuer2.Certificate, "") require.NoError(t, err) require.True(t, existing) require.Equal(t, issuer2.Certificate, issuer2_ref.Certificate) diff --git a/builtin/logical/pki/util.go b/builtin/logical/pki/util.go index 3ed0f14bc673..f6737bdbd149 100644 --- a/builtin/logical/pki/util.go +++ b/builtin/logical/pki/util.go @@ -23,13 +23,29 @@ func denormalizeSerial(serial string) string { } func kmsRequested(input *inputBundle) bool { - exportedStr, ok := input.apiData.GetOk("exported") + return kmsRequestedFromFieldData(input.apiData) +} + +func kmsRequestedFromFieldData(data *framework.FieldData) bool { + exportedStr, ok := data.GetOk("exported") if !ok { return false } return exportedStr.(string) == "kms" } +func existingKeyRequested(input *inputBundle) bool { + return existingKeyRequestedFromFieldData(input.apiData) +} + +func existingKeyRequestedFromFieldData(data *framework.FieldData) bool { + exportedStr, ok := data.GetOk("exported") + if !ok { + return false + } + return exportedStr.(string) == "existing" +} + type managedKeyId interface { String() string } @@ -63,6 +79,19 @@ func getManagedKeyId(data *framework.FieldData) (managedKeyId, error) { return keyId, nil } +func getExistingKeyRef(data *framework.FieldData) (string, error) { + keyRef, ok := data.GetOk("key_id") + if !ok { + return "", errutil.UserError{Err: fmt.Sprintf("missing argument key_id for existing type")} + } + trimmedKeyRef := strings.TrimSpace(keyRef.(string)) + if len(trimmedKeyRef) == 0 { + return "", errutil.UserError{Err: fmt.Sprintf("missing argument key_id for existing type")} + } + + return trimmedKeyRef, nil +} + func getManagedKeyNameOrUUID(data *framework.FieldData) (name string, UUID string, err error) { getApiData := func(argName string) (string, error) { arg, ok := data.GetOk(argName)