Skip to content

Commit

Permalink
Add utility methods to fetch common ref and name arguments
Browse files Browse the repository at this point in the history
 - Add utility methods to fetch the issuer_name, issuer_ref, key_name and key_ref arguments from data fields.
 - Centralize the logic to clean up these inputs and apply various validations to all of them.
  • Loading branch information
stevendpclark committed Apr 11, 2022
1 parent e301ef4 commit d2f24d7
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 44 deletions.
36 changes: 36 additions & 0 deletions builtin/logical/pki/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4656,6 +4656,23 @@ func TestRootWithExistingKey(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "unable to find PKI key for reference: my-key1")

// Fail if the specified key name is default.
_, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{
"common_name": "root myvault.com",
"issuer_name": "my-issuer1",
"key_name": "Default",
})
require.Error(t, err)
require.Contains(t, err.Error(), "reserved keyword 'default' can not be used as key name")

// Fail if the specified issuer name is default.
_, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{
"common_name": "root myvault.com",
"issuer_name": "DEFAULT",
})
require.Error(t, err)
require.Contains(t, err.Error(), "reserved keyword 'default' can not be used as issuer name")

// Create the first CA
resp, err := client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{
"common_name": "root myvault.com",
Expand All @@ -4669,11 +4686,20 @@ func TestRootWithExistingKey(t *testing.T) {
require.NotEmpty(t, myIssuerId1)
require.NotEmpty(t, myKeyId1)

// Fail if the specified issuer name is re-used.
_, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{
"common_name": "root myvault.com",
"issuer_name": "my-issuer1",
})
require.Error(t, err)
require.Contains(t, err.Error(), "issuer name already used")

// Create the second CA
resp, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{
"common_name": "root myvault.com",
"key_type": "rsa",
"issuer_name": "my-issuer2",
"key_name": "root-key2",
})
require.NoError(t, err)
require.NotNil(t, resp.Data["certificate"])
Expand All @@ -4682,6 +4708,15 @@ func TestRootWithExistingKey(t *testing.T) {
require.NotEmpty(t, myIssuerId2)
require.NotEmpty(t, myKeyId2)

// Fail if the specified key name is re-used.
_, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{
"common_name": "root myvault.com",
"issuer_name": "my-issuer3",
"key_name": "root-key2",
})
require.Error(t, err)
require.Contains(t, err.Error(), "key name already used")

// Create a third CA re-using key from CA 1
resp, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/existing", map[string]interface{}{
"common_name": "root myvault.com",
Expand Down Expand Up @@ -4764,6 +4799,7 @@ func TestIntermediateWithExistingKey(t *testing.T) {
resp, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/intermediate/internal", map[string]interface{}{
"common_name": "root myvault.com",
"key_type": "rsa",
"key_name": "interkey1",
})
require.NoError(t, err)
// csr2 := resp.Data["csr"]
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/pki/ca_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func getKeyTypeAndBitsForRole(ctx context.Context, b *backend, data *framework.F
}

func getExistingPublicKey(ctx context.Context, s logical.Storage, data *framework.FieldData) (crypto.PublicKey, error) {
keyRef, err := getExistingKeyRef(data)
keyRef, err := getKeyRefWithErr(data)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions builtin/logical/pki/managed_key_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func generateCABundle(ctx context.Context, _ *backend, input *inputBundle, data
return nil, errEntOnly
}
if existingKeyRequested(input) {
keyRef, err := getExistingKeyRef(input.apiData)
keyRef, err := getKeyRefWithErr(input.apiData)
if err != nil {
return nil, err
}
Expand All @@ -33,7 +33,7 @@ func generateCSRBundle(ctx context.Context, _ *backend, input *inputBundle, data
return nil, errEntOnly
}
if existingKeyRequested(input) {
keyRef, err := getExistingKeyRef(input.apiData)
keyRef, err := getKeyRefWithErr(input.apiData)
if err != nil {
return nil, err
}
Expand Down
16 changes: 8 additions & 8 deletions builtin/logical/pki/path_fetch_issuers.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func pathListIssuers(b *backend) *framework.Path {
}
}

func (b *backend) pathListIssuersHandler(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
func (b *backend) pathListIssuersHandler(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
var responseKeys []string
responseInfo := make(map[string]interface{})

Expand Down Expand Up @@ -90,7 +90,7 @@ func (b *backend) pathGetIssuer(ctx context.Context, req *logical.Request, data
return b.pathGetRawIssuer(ctx, req, data)
}

issuerName := data.Get("issuer_ref").(string)
issuerName := getIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
Expand Down Expand Up @@ -119,14 +119,14 @@ func (b *backend) pathGetIssuer(ctx context.Context, req *logical.Request, data
}

func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
issuerName := data.Get("issuer_ref").(string)
issuerName := getIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}

newName := data.Get("issuer_name").(string)
if len(newName) > 0 && !nameMatcher.MatchString(newName) {
return logical.ErrorResponse("new issuer name outside of valid character limits"), nil
newName, err := getIssuerName(ctx, req.Storage, data)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}

ref, err := resolveIssuerReference(ctx, req.Storage, issuerName)
Expand Down Expand Up @@ -162,7 +162,7 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
}

func (b *backend) pathGetRawIssuer(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
issuerName := data.Get("issuer_ref").(string)
issuerName := getIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
Expand Down Expand Up @@ -209,7 +209,7 @@ func (b *backend) pathGetRawIssuer(ctx context.Context, req *logical.Request, da
}

func (b *backend) pathDeleteIssuer(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
issuerName := data.Get("issuer_ref").(string)
issuerName := getIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
Expand Down
6 changes: 5 additions & 1 deletion builtin/logical/pki/path_intermediate.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ func (b *backend) pathGenerateIntermediate(ctx context.Context, req *logical.Req
return errorResp, nil
}

keyName, err := getKeyName(ctx, req.Storage, data)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
var resp *logical.Response
input := &inputBundle{
role: role,
Expand Down Expand Up @@ -110,7 +114,7 @@ func (b *backend) pathGenerateIntermediate(ctx context.Context, req *logical.Req
}
}

myKey, _, err := importKey(ctx, req.Storage, csrb.PrivateKey)
myKey, _, err := importKey(ctx, req.Storage, csrb.PrivateKey, keyName)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/pki/path_issue_sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d
return nil, logical.ErrReadOnly
}

issuerName := data.Get("issuer_ref").(string)
issuerName := getIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/pki/path_manage_issuers.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (b *backend) pathImportIssuers(ctx context.Context, req *logical.Request, d

for _, keyPem := range keys {
// Handle import of private key.
key, existing, err := importKey(ctx, req.Storage, keyPem)
key, existing, err := importKey(ctx, req.Storage, keyPem, "")
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
Expand Down
17 changes: 10 additions & 7 deletions builtin/logical/pki/path_root.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,13 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
role.MaxPathLength = &maxPathLength
}

issuerName := ""
issuerNameIface, ok := data.GetOk("id")
if ok {
issuerName = issuerNameIface.(string)
issuerName, err := getIssuerName(ctx, req.Storage, data)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
keyName, err := getKeyName(ctx, req.Storage, data)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}

input := &inputBundle{
Expand Down Expand Up @@ -149,7 +152,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
}

// Store it as the CA bundle
myIssuer, myKey, err := writeCaBundle(ctx, req.Storage, cb, issuerName)
myIssuer, myKey, err := writeCaBundle(ctx, req.Storage, cb, issuerName, keyName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -192,7 +195,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var err error

issuerName := data.Get("issuer_ref").(string)
issuerName := getIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
Expand Down Expand Up @@ -341,7 +344,7 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R
func (b *backend) pathIssuerSignSelfIssued(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var err error

issuerName := data.Get("issuer_ref").(string)
issuerName := getIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
Expand Down
16 changes: 11 additions & 5 deletions builtin/logical/pki/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ func (p issuerId) String() string {
return string(p)
}

const (
IssuerRefNotFound = issuerId("not-found")
KeyRefNotFound = keyId("not-found")
)

type key struct {
ID keyId `json:"id" structs:"id" mapstructure:"id"`
Name string `json:"name" structs:"name" mapstructure:"name"`
Expand Down Expand Up @@ -112,7 +117,7 @@ func deleteKey(ctx context.Context, s logical.Storage, id keyId) error {
return s.Delete(ctx, keyPrefix+id.String())
}

func importKey(ctx context.Context, s logical.Storage, keyValue string) (*key, bool, error) {
func importKey(ctx context.Context, s logical.Storage, keyValue string, keyName string) (*key, bool, error) {
// importKey imports the specified PEM-format key (from keyValue) into
// the new PKI storage format. The first return field is a reference to
// the new key; the second is whether or not the key already existed
Expand Down Expand Up @@ -154,6 +159,7 @@ func importKey(ctx context.Context, s logical.Storage, keyValue string) (*key, b
// Haven't found a key, so we've gotta create it and write it into storage.
var result key
result.ID = genKeyId()
result.Name = keyName
result.PrivateKey = keyValue

// Extracting the signer is necessary for two reasons: first, to get the
Expand Down Expand Up @@ -270,7 +276,7 @@ func resolveKeyReference(ctx context.Context, s logical.Storage, reference strin
}

// Otherwise, we must not have found the key.
return keyId("not-found"), errutil.UserError{Err: fmt.Sprintf("unable to find PKI key for reference: %v", reference)}
return KeyRefNotFound, errutil.UserError{Err: fmt.Sprintf("unable to find PKI key for reference: %v", reference)}
}

func fetchIssuerById(ctx context.Context, s logical.Storage, issuerId issuerId) (*issuer, error) {
Expand Down Expand Up @@ -488,7 +494,7 @@ func resolveIssuerReference(ctx context.Context, s logical.Storage, reference st
}

// Otherwise, we must not have found the issuer.
return issuerId("not-found"), errutil.UserError{Err: fmt.Sprintf("unable to find PKI issuer for reference: %v", reference)}
return IssuerRefNotFound, errutil.UserError{Err: fmt.Sprintf("unable to find PKI issuer for reference: %v", reference)}
}

// Builds a certutil.CertBundle from the specified issuer identifier,
Expand Down Expand Up @@ -519,7 +525,7 @@ func fetchCertBundleByIssuerId(ctx context.Context, s logical.Storage, id issuer
return &bundle, nil
}

func writeCaBundle(ctx context.Context, s logical.Storage, caBundle *certutil.CertBundle, issuerName string) (*issuer, *key, error) {
func writeCaBundle(ctx context.Context, s logical.Storage, caBundle *certutil.CertBundle, issuerName string, keyName string) (*issuer, *key, error) {
allKeyIds, err := listKeys(ctx, s)
if err != nil {
return nil, nil, err
Expand All @@ -530,7 +536,7 @@ func writeCaBundle(ctx context.Context, s logical.Storage, caBundle *certutil.Ce
return nil, nil, err
}

myKey, _, err := importKey(ctx, s, caBundle.PrivateKey)
myKey, _, err := importKey(ctx, s, caBundle.PrivateKey, keyName)
if err != nil {
return nil, nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/pki/storage_migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func migrateStorage(ctx context.Context, req *logical.InitializationRequest, log

logger.Warn("performing PKI migration to new keys/issuers layout")

anIssuer, aKey, err := writeCaBundle(ctx, s, legacyBundle, "")
anIssuer, aKey, err := writeCaBundle(ctx, s, legacyBundle, "", "")
if err != nil {
return err
}
Expand Down
23 changes: 14 additions & 9 deletions builtin/logical/pki/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,48 +102,53 @@ func Test_KeysIssuerImport(t *testing.T) {
issuer1.ID = ""
issuer1.KeyID = ""

key1_ref1, existing, err := importKey(ctx, s, key1.PrivateKey)
key1_ref1, existing, err := importKey(ctx, s, key1.PrivateKey, "key1")
require.NoError(t, err)
require.False(t, existing)
require.Equal(t, key1.PrivateKey, key1_ref1.PrivateKey)

key1_ref2, existing, err := importKey(ctx, s, key1.PrivateKey)
key1_ref2, existing, err := importKey(ctx, s, key1.PrivateKey, "ignore-me")
require.NoError(t, err)
require.True(t, existing)
require.Equal(t, key1.PrivateKey, key1_ref1.PrivateKey)
require.Equal(t, key1_ref1.ID, key1_ref2.ID)
require.Equal(t, key1_ref1.Name, key1_ref2.Name)

issuer1_ref1, existing, err := importIssuer(ctx, s, issuer1.Certificate, "")
issuer1_ref1, existing, err := importIssuer(ctx, s, issuer1.Certificate, "issuer1")
require.NoError(t, err)
require.False(t, existing)
require.Equal(t, issuer1.Certificate, issuer1_ref1.Certificate)
require.Equal(t, key1_ref1.ID, issuer1_ref1.KeyID)
require.Equal(t, "issuer1", issuer1_ref1.Name)

issuer1_ref2, existing, err := importIssuer(ctx, s, issuer1.Certificate, "")
issuer1_ref2, existing, err := importIssuer(ctx, s, issuer1.Certificate, "ignore-me")
require.NoError(t, err)
require.True(t, existing)
require.Equal(t, issuer1.Certificate, issuer1_ref1.Certificate)
require.Equal(t, issuer1_ref1.ID, issuer1_ref2.ID)
require.Equal(t, key1_ref1.ID, issuer1_ref2.KeyID)
require.Equal(t, issuer1_ref1.Name, issuer1_ref2.Name)

err = writeIssuer(ctx, s, &issuer2)
require.NoError(t, err)

err = writeKey(ctx, s, key2)
require.NoError(t, err)

issuer2_ref, existing, err := importIssuer(ctx, s, issuer2.Certificate, "")
issuer2_ref, existing, err := importIssuer(ctx, s, issuer2.Certificate, "ignore-me")
require.NoError(t, err)
require.True(t, existing)
require.Equal(t, issuer2.Certificate, issuer2_ref.Certificate)
require.Equal(t, issuer2_ref.ID, issuer2.ID)
require.Equal(t, issuer2_ref.KeyID, issuer2.KeyID)
require.Equal(t, issuer2.ID, issuer2_ref.ID)
require.Equal(t, "", issuer2_ref.Name)
require.Equal(t, issuer2.KeyID, issuer2_ref.KeyID)

key2_ref, existing, err := importKey(ctx, s, key2.PrivateKey)
key2_ref, existing, err := importKey(ctx, s, key2.PrivateKey, "ignore-me")
require.NoError(t, err)
require.True(t, existing)
require.Equal(t, key2.PrivateKey, key2_ref.PrivateKey)
require.Equal(t, key2_ref.ID, key2.ID)
require.Equal(t, key2.ID, key2_ref.ID)
require.Equal(t, "", key2_ref.Name)
}

func genIssuerAndKey(t *testing.T, b *backend) (issuer, key) {
Expand Down
Loading

0 comments on commit d2f24d7

Please sign in to comment.