diff --git a/builtin/logical/pki/backend_test.go b/builtin/logical/pki/backend_test.go index d271e7ed844a..8a96b81c2c4d 100644 --- a/builtin/logical/pki/backend_test.go +++ b/builtin/logical/pki/backend_test.go @@ -30,6 +30,7 @@ import ( "github.com/fatih/structs" "github.com/go-test/deep" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/credential/userpass" logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/sdk/helper/certutil" @@ -2718,6 +2719,126 @@ func TestBackend_URI_SANs(t *testing.T) { cert.URIs) } } + +func TestBackend_AllowedDomainsTemplate(t *testing.T) { + coreConfig := &vault.CoreConfig{ + CredentialBackends: map[string]logical.Factory{ + "userpass": userpass.Factory, + }, + LogicalBackends: map[string]logical.Factory{ + "pki": Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + client := cluster.Cores[0].Client + + // Write test policy for userpass auth method. + err := client.Sys().PutPolicy("test", ` + path "pki/*" { + capabilities = ["update"] + }`) + if err != nil { + t.Fatal(err) + } + + // Enable userpass auth method. + if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil { + t.Fatal(err) + } + + // Configure test role for userpass. + if _, err := client.Logical().Write("auth/userpass/users/userpassname", map[string]interface{}{ + "password": "test", + "policies": "test", + }); err != nil { + t.Fatal(err) + } + + // Login userpass for test role and keep client token. + secret, err := client.Logical().Write("auth/userpass/login/userpassname", map[string]interface{}{ + "password": "test", + }) + if err != nil || secret == nil { + t.Fatal(err) + } + userpassToken := secret.Auth.ClientToken + + // Get auth accessor for identity template. + auths, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + userpassAccessor := auths["userpass/"].Accessor + + // Mount PKI. + err = client.Sys().Mount("pki", &api.MountInput{ + Type: "pki", + Config: api.MountConfigInput{ + DefaultLeaseTTL: "16h", + MaxLeaseTTL: "60h", + }, + }) + if err != nil { + t.Fatal(err) + } + + // Generate internal CA. + _, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{ + "ttl": "40h", + "common_name": "myvault.com", + }) + if err != nil { + t.Fatal(err) + } + + // Write role PKI. + _, err = client.Logical().Write("pki/roles/test", map[string]interface{}{ + "allowed_domains": []string{"foobar.com", "zipzap.com", "{{identity.entity.aliases." + userpassAccessor + ".name}}"}, + "allowed_domains_template": true, + "allow_bare_domains": true, + }) + if err != nil { + t.Fatal(err) + } + + // Issue certificate with userpassToken. + client.SetToken(userpassToken) + _, err = client.Logical().Write("pki/issue/test", map[string]interface{}{"common_name": "userpassname"}) + if err != nil { + t.Fatal(err) + } + + // Issue certificate for foobar.com to verify allowed_domain_templae doesnt break plain domains. + _, err = client.Logical().Write("pki/issue/test", map[string]interface{}{"common_name": "foobar.com"}) + if err != nil { + t.Fatal(err) + } + + // Issue certificate for unknown userpassname. + _, err = client.Logical().Write("pki/issue/test", map[string]interface{}{"common_name": "unknownuserpassname"}) + if err == nil { + t.Fatal("expected error") + } + + // Set allowed_domains_template to false. + _, err = client.Logical().Write("pki/roles/test", map[string]interface{}{ + "allowed_domains_template": false, + }) + if err != nil { + t.Fatal(err) + } + + // Issue certificate with userpassToken. + _, err = client.Logical().Write("pki/issue/test", map[string]interface{}{"common_name": "userpassname"}) + if err == nil { + t.Fatal("expected error") + } +} + func setCerts() { cak, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { diff --git a/builtin/logical/pki/cert_util.go b/builtin/logical/pki/cert_util.go index fbf8cdd5b56c..608dbf014c22 100644 --- a/builtin/logical/pki/cert_util.go +++ b/builtin/logical/pki/cert_util.go @@ -178,7 +178,7 @@ func fetchCertBySerial(ctx context.Context, req *logical.Request, prefix, serial // Given a set of requested names for a certificate, verifies that all of them // match the various toggles set in the role for controlling issuance. // If one does not pass, it is returned in the string argument. -func validateNames(data *inputBundle, names []string) string { +func validateNames(b *backend, data *inputBundle, names []string) string { for _, name := range names { sanitizedName := name emailDomain := name @@ -314,6 +314,18 @@ func validateNames(data *inputBundle, names []string) string { continue } + if data.role.AllowedDomainsTemplate { + matched, _ := regexp.MatchString(`^{{.+?}}$`, currDomain) + if matched && data.req.EntityID != "" { + tmpCurrDomain, err := framework.PopulateIdentityTemplate(currDomain, data.req.EntityID, b.System()) + if err != nil { + continue + } + + currDomain = tmpCurrDomain + } + } + // First, allow an exact match of the base domain if that role flag // is enabled if data.role.AllowBareDomains && @@ -338,6 +350,7 @@ func validateNames(data *inputBundle, names []string) string { break } } + if valid { continue } @@ -816,7 +829,7 @@ func generateCreationBundle(b *backend, data *inputBundle, caSign *certutil.CAIn // Check the CN. This ensures that the CN is checked even if it's // excluded from SANs. if cn != "" { - badName := validateNames(data, []string{cn}) + badName := validateNames(b, data, []string{cn}) if len(badName) != 0 { return nil, errutil.UserError{Err: fmt.Sprintf( "common name %s not allowed by this role", badName)} @@ -832,13 +845,13 @@ func generateCreationBundle(b *backend, data *inputBundle, caSign *certutil.CAIn } // Check for bad email and/or DNS names - badName := validateNames(data, dnsNames) + badName := validateNames(b, data, dnsNames) if len(badName) != 0 { return nil, errutil.UserError{Err: fmt.Sprintf( "subject alternate name %s not allowed by this role", badName)} } - badName = validateNames(data, emailAddresses) + badName = validateNames(b, data, emailAddresses) if len(badName) != 0 { return nil, errutil.UserError{Err: fmt.Sprintf( "email address %s not allowed by this role", badName)} diff --git a/builtin/logical/pki/path_roles.go b/builtin/logical/pki/path_roles.go index 6b6757631dfc..4d23484079e7 100644 --- a/builtin/logical/pki/path_roles.go +++ b/builtin/logical/pki/path_roles.go @@ -79,7 +79,12 @@ the wildcard subdomains. See the documentation for more information. This parameter accepts a comma-separated string or list of domains.`, }, - + "allowed_domains_template": &framework.FieldSchema{ + Type: framework.TypeBool, + Description: `If set, Allowed domains can be specified using identity template policies. + Non-templated domains are also permitted.`, + Default: false, + }, "allow_bare_domains": &framework.FieldSchema{ Type: framework.TypeBool, Description: `If set, clients can request certificates @@ -541,6 +546,7 @@ func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data TTL: time.Duration(data.Get("ttl").(int)) * time.Second, AllowLocalhost: data.Get("allow_localhost").(bool), AllowedDomains: data.Get("allowed_domains").([]string), + AllowedDomainsTemplate: data.Get("allowed_domains_template").(bool), AllowBareDomains: data.Get("allow_bare_domains").(bool), AllowSubdomains: data.Get("allow_subdomains").(bool), AllowGlobDomains: data.Get("allow_glob_domains").(bool), @@ -728,6 +734,7 @@ type roleEntry struct { AllowedBaseDomain string `json:"allowed_base_domain" mapstructure:"allowed_base_domain"` AllowedDomainsOld string `json:"allowed_domains,omitempty"` AllowedDomains []string `json:"allowed_domains_list" mapstructure:"allowed_domains"` + AllowedDomainsTemplate bool `json:"allowed_domains_template"` AllowBaseDomain bool `json:"allow_base_domain"` AllowBareDomains bool `json:"allow_bare_domains" mapstructure:"allow_bare_domains"` AllowTokenDisplayName bool `json:"allow_token_displayname" mapstructure:"allow_token_displayname"` @@ -778,6 +785,7 @@ func (r *roleEntry) ToResponseData() map[string]interface{} { "max_ttl": int64(r.MaxTTL.Seconds()), "allow_localhost": r.AllowLocalhost, "allowed_domains": r.AllowedDomains, + "allowed_domains_template": r.AllowedDomainsTemplate, "allow_bare_domains": r.AllowBareDomains, "allow_token_displayname": r.AllowTokenDisplayName, "allow_subdomains": r.AllowSubdomains, diff --git a/builtin/logical/pki/path_roles_test.go b/builtin/logical/pki/path_roles_test.go index 01e49983c505..791cc584a232 100644 --- a/builtin/logical/pki/path_roles_test.go +++ b/builtin/logical/pki/path_roles_test.go @@ -586,6 +586,12 @@ func TestPki_RoleNoStore(t *testing.T) { t.Fatalf("no_store should not be set by default") } + // By default, allowed_domains_template should be `false` + allowedDomainsTemplate := resp.Data["allowed_domains_template"].(bool) + if allowedDomainsTemplate { + t.Fatalf("allowed_domains_template should not be set by default") + } + // Make sure that setting no_store to `true` works properly roleReq.Operation = logical.UpdateOperation roleReq.Path = "roles/testrole_nostore"