Skip to content

Commit

Permalink
WIP: Support root issuer generation
Browse files Browse the repository at this point in the history
  • Loading branch information
stevendpclark committed Apr 8, 2022
1 parent 1b46d1c commit a76ec20
Show file tree
Hide file tree
Showing 12 changed files with 396 additions and 90 deletions.
1 change: 1 addition & 0 deletions builtin/logical/pki/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ func Backend(conf *logical.BackendConfig) *backend {
pathIssuerSignIntermediate(&b),
pathIssuerSignSelfIssued(&b),
pathIssuerSignVerbatim(&b),
pathIssuerGenerateRoot(&b),
pathConfigIssuers(&b),

// Fetch APIs have been lowered to favor the newer issuer API endpoints
Expand Down
95 changes: 95 additions & 0 deletions builtin/logical/pki/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4613,6 +4613,101 @@ func TestBackend_Roles_KeySizeRegression(t *testing.T) {
t.Log(fmt.Sprintf("Key size regression expanded matrix test scenarios: %d", tested))
}

func TestRootWithExistingKey(t *testing.T) {
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"pki": Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()

client := cluster.Cores[0].Client
var err error

mountPKIEndpoint(t, client, "pki-root")

// Fail requests if type is existing, and we specify the key_type param
ctx := context.Background()
_, err = client.Logical().WriteWithContext(ctx, "pki-root/root/generate/existing", map[string]interface{}{
"common_name": "root myvault.com",
"key_type": "rsa",
})
require.Error(t, err)
require.Contains(t, err.Error(), "key_type nor key_bits arguments can be set in this mode")

// Fail requests if type is existing, and we specify the key_bits param
_, err = client.Logical().WriteWithContext(ctx, "pki-root/root/generate/existing", map[string]interface{}{
"common_name": "root myvault.com",
"key_bits": "2048",
})
require.Error(t, err)
require.Contains(t, err.Error(), "key_type nor key_bits arguments can be set in this mode")

// Fail if the specified key does not exist.
_, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/existing", map[string]interface{}{
"common_name": "root myvault.com",
"id": "my-issuer1",
"key_id": "my-key1",
})
require.Error(t, err)
require.Contains(t, err.Error(), "unable to find PKI key for reference: my-key1")

// Create the first CA
resp, err := client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{
"common_name": "root myvault.com",
"key_type": "rsa",
"id": "my-issuer1",
})
require.NoError(t, err)
require.NotNil(t, resp.Data["certificate"])
myIssuerId1 := resp.Data["id"]
myKeyId1 := resp.Data["key_id"]
require.NotEmpty(t, myIssuerId1)
require.NotEmpty(t, myKeyId1)

// Create the second CA
resp, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/internal", map[string]interface{}{
"common_name": "root myvault.com",
"key_type": "rsa",
"id": "my-issuer2",
})
require.NoError(t, err)
require.NotNil(t, resp.Data["certificate"])
myIssuerId2 := resp.Data["id"]
myKeyId2 := resp.Data["key_id"]
require.NotEmpty(t, myIssuerId2)
require.NotEmpty(t, myKeyId2)

// Create a third CA re-using key from CA 1
resp, err = client.Logical().WriteWithContext(ctx, "pki-root/issuers/generate/root/existing", map[string]interface{}{
"common_name": "root myvault.com",
"id": "my-issuer3",
"key_id": myKeyId1,
})
require.NoError(t, err)
require.NotNil(t, resp.Data["certificate"])
myIssuerId3 := resp.Data["id"]
myKeyId3 := resp.Data["key_id"]
require.NotEmpty(t, myIssuerId3)
require.NotEmpty(t, myKeyId3)

require.NotEqual(t, myIssuerId1, myIssuerId2)
require.NotEqual(t, myIssuerId1, myIssuerId3)
require.NotEqual(t, myKeyId1, myKeyId2)
require.Equal(t, myKeyId1, myKeyId3)

resp, err = client.Logical().ListWithContext(ctx, "pki-root/issuers")
require.NoError(t, err)
require.Equal(t, 3, len(resp.Data["keys"].([]interface{})))
require.Contains(t, resp.Data["keys"], myIssuerId1)
require.Contains(t, resp.Data["keys"], myIssuerId2)
require.Contains(t, resp.Data["keys"], myIssuerId3)
}

var (
initTest sync.Once
rsaCAKey string
Expand Down
155 changes: 110 additions & 45 deletions builtin/logical/pki/ca_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package pki

import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"errors"
"fmt"
"time"

Expand All @@ -14,18 +16,17 @@ import (
"github.com/hashicorp/vault/sdk/logical"
)

func (b *backend) getGenerationParams(ctx context.Context,
data *framework.FieldData, mountPoint string,
) (exported bool, format string, role *roleEntry, errorResp *logical.Response) {
func (b *backend) getGenerationParams(ctx context.Context, data *framework.FieldData, mountPoint string) (exported bool, format string, role *roleEntry, errorResp *logical.Response) {
exportedStr := data.Get("exported").(string)
switch exportedStr {
case "exported":
exported = true
case "internal":
case "existing":
case "kms":
default:
errorResp = logical.ErrorResponse(
`the "exported" path parameter must be "internal", "exported" or "kms"`)
`the "exported" path parameter must be "internal", "existing", exported" or "kms"`)
return
}

Expand All @@ -36,46 +37,10 @@ func (b *backend) getGenerationParams(ctx context.Context,
return
}

keyType := data.Get("key_type").(string)
keyBits := data.Get("key_bits").(int)
if exportedStr == "kms" {
_, okKeyType := data.Raw["key_type"]
_, okKeyBits := data.Raw["key_bits"]

if okKeyType || okKeyBits {
errorResp = logical.ErrorResponse(
`invalid parameter for the kms path parameter, key_type nor key_bits arguments can be set in this mode`)
return
}

keyId, err := getManagedKeyId(data)
if err != nil {
errorResp = logical.ErrorResponse("unable to determine managed key id")
return
}
// Determine key type and key bits from the managed public key
err = withManagedPKIKey(ctx, b, keyId, mountPoint, func(ctx context.Context, key logical.ManagedSigningKey) error {
pubKey, err := key.GetPublicKey(ctx)
if err != nil {
return err
}
switch pubKey.(type) {
case *rsa.PublicKey:
keyType = "rsa"
keyBits = pubKey.(*rsa.PublicKey).Size() * 8
case *ecdsa.PublicKey:
keyType = "ec"
case *ed25519.PublicKey:
keyType = "ed25519"
default:
return fmt.Errorf("unsupported public key: %#v", pubKey)
}
return nil
})
if err != nil {
errorResp = logical.ErrorResponse("failed to lookup public key from managed key: %s", err.Error())
return
}
keyType, keyBits, err := getKeyTypeAndBitsForRole(ctx, b, data, mountPoint)
if err != nil {
errorResp = logical.ErrorResponse(err.Error())
return
}

role = &roleEntry{
Expand All @@ -101,10 +66,110 @@ func (b *backend) getGenerationParams(ctx context.Context,
}
*role.AllowWildcardCertificates = true

var err error
if role.KeyBits, role.SignatureBits, err = certutil.ValidateDefaultOrValueKeyTypeSignatureLength(role.KeyType, role.KeyBits, role.SignatureBits); err != nil {
errorResp = logical.ErrorResponse(err.Error())
}

return
}

func getKeyTypeAndBitsForRole(ctx context.Context, b *backend, data *framework.FieldData, mountPoint string) (string, int, error) {
exportedStr := data.Get("exported").(string)
var keyType string
var keyBits int

switch exportedStr {
case "internal":
fallthrough
case "exported":
keyType = data.Get("key_type").(string)
keyBits = data.Get("key_bits").(int)
return keyType, keyBits, nil
}

// existing and kms types don't support providing the key_type and key_bits args.
_, okKeyType := data.Raw["key_type"]
_, okKeyBits := data.Raw["key_bits"]

if okKeyType || okKeyBits {
return "", 0, errors.New("invalid parameter for the kms/existing path parameter, key_type nor key_bits arguments can be set in this mode")
}

var pubKey crypto.PublicKey
if kmsRequestedFromFieldData(data) {
pubKeyManagedKey, err := getManagedKeyPublicKey(ctx, b, data, mountPoint)
if err != nil {
return "", 0, errors.New("failed to lookup public key from managed key: " + err.Error())
}
pubKey = pubKeyManagedKey
}

if existingKeyRequestedFromFieldData(data) {
existingPubKey, err := getExistingPublicKey(ctx, b.storage, data)
if err != nil {
return "", 0, errors.New("failed to lookup public key from existing key: " + err.Error())
}
pubKey = existingPubKey
}

return getKeyTypeAndBitsFromPublicKeyForRole(pubKey)
}

func getExistingPublicKey(ctx context.Context, s logical.Storage, data *framework.FieldData) (crypto.PublicKey, error) {
keyRef, err := getExistingKeyRef(data)
if err != nil {
return nil, err
}
id, err := resolveKeyReference(ctx, s, keyRef)
if err != nil {
return nil, err
}
key, err := fetchKeyById(ctx, s, id)
if err != nil {
return nil, err
}
signer, err := key.GetSigner()
if err != nil {
return nil, err
}
return signer.Public(), nil
}

func getKeyTypeAndBitsFromPublicKeyForRole(pubKey crypto.PublicKey) (string, int, error) {
var keyType string
var keyBits int

switch pubKey.(type) {
case *rsa.PublicKey:
keyType = "rsa"
keyBits = certutil.GetPublicKeySize(pubKey)
case *ecdsa.PublicKey:
keyType = "ec"
case *ed25519.PublicKey:
keyType = "ed25519"
default:
return "", 0, fmt.Errorf("unsupported public key: %#v", pubKey)
}
return keyType, keyBits, nil
}

func getManagedKeyPublicKey(ctx context.Context, b *backend, data *framework.FieldData, mountPoint string) (crypto.PublicKey, error) {
keyId, err := getManagedKeyId(data)
if err != nil {
return nil, errors.New("unable to determine managed key id")
}
// Determine key type and key bits from the managed public key
var pubKey crypto.PublicKey
err = withManagedPKIKey(ctx, b, keyId, mountPoint, func(ctx context.Context, key logical.ManagedSigningKey) error {
pubKey, err = key.GetPublicKey(ctx)
if err != nil {
return err
}

return nil
})
if err != nil {
return nil, errors.New("failed to lookup public key from managed key: " + err.Error())
}
return pubKey, nil
}
19 changes: 19 additions & 0 deletions builtin/logical/pki/config_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,29 @@ package pki

import (
"context"
"strings"

"github.com/hashicorp/vault/sdk/logical"
)

func isKeyDefaultSet(ctx context.Context, s logical.Storage) (bool, error) {
config, err := getKeysConfig(ctx, s)
if err != nil {
return false, err
}

return strings.TrimSpace(config.DefaultKeyId.String()) != "", nil
}

func isIssuerDefaultSet(ctx context.Context, s logical.Storage) (bool, error) {
config, err := getIssuersConfig(ctx, s)
if err != nil {
return false, err
}

return strings.TrimSpace(config.DefaultIssuerId.String()) != "", nil
}

func updateDefaultKeyId(ctx context.Context, s logical.Storage, id keyId) error {
config, err := getKeysConfig(ctx, s)
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions builtin/logical/pki/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,18 @@ SHA-2-512. Defaults to 0 to automatically detect based on key length
Value: "rsa",
},
}

fields["key_id"] = &framework.FieldSchema{
Type: framework.TypeString,
Description: `Reference to a existing key; either "default"
for the configured default key, an identifier or the name assigned
to the key. Note this is only used for the existing generation type.`,
}

fields["id"] = &framework.FieldSchema{
Type: framework.TypeString,
Description: `Assign a name to the generated issuer.`,
}
return fields
}

Expand Down
Loading

0 comments on commit a76ec20

Please sign in to comment.