diff --git a/builtin/credential/cert/backend_test.go b/builtin/credential/cert/backend_test.go index eb063491d968..8e2db8dfc312 100644 --- a/builtin/credential/cert/backend_test.go +++ b/builtin/credential/cert/backend_test.go @@ -11,6 +11,7 @@ import ( "net/url" "path/filepath" + "github.com/go-test/deep" "github.com/hashicorp/go-sockaddr" "golang.org/x/net/http2" @@ -39,6 +40,7 @@ import ( logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/certutil" + "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" "github.com/mitchellh/mapstructure" @@ -1949,3 +1951,60 @@ func Test_Renew(t *testing.T) { t.Fatal("expected error") } } + +func TestBackend_CertUpgrade(t *testing.T) { + s := &logical.InmemStorage{} + + config := logical.TestBackendConfig() + config.StorageView = s + + ctx := context.Background() + + b := Backend() + if b == nil { + t.Fatalf("failed to create backend") + } + if err := b.Setup(ctx, config); err != nil { + t.Fatal(err) + } + + foo := &CertEntry{ + Policies: []string{"foo"}, + Period: time.Second, + TTL: time.Second, + MaxTTL: time.Second, + BoundCIDRs: []*sockaddr.SockAddrMarshaler{&sockaddr.SockAddrMarshaler{SockAddr: sockaddr.MustIPAddr("127.0.0.1")}}, + } + + entry, err := logical.StorageEntryJSON("cert/foo", foo) + if err != nil { + t.Fatal(err) + } + err = s.Put(ctx, entry) + if err != nil { + t.Fatal(err) + } + + certEntry, err := b.Cert(ctx, s, "foo") + if err != nil { + t.Fatal(err) + } + + exp := &CertEntry{ + Policies: []string{"foo"}, + Period: time.Second, + TTL: time.Second, + MaxTTL: time.Second, + BoundCIDRs: []*sockaddr.SockAddrMarshaler{&sockaddr.SockAddrMarshaler{SockAddr: sockaddr.MustIPAddr("127.0.0.1")}}, + TokenParams: tokenutil.TokenParams{ + TokenPolicies: []string{"foo"}, + TokenPeriod: time.Second, + TokenTTL: time.Second, + TokenMaxTTL: time.Second, + TokenBoundCIDRs: []*sockaddr.SockAddrMarshaler{&sockaddr.SockAddrMarshaler{SockAddr: sockaddr.MustIPAddr("127.0.0.1")}}, + }, + } + if diff := deep.Equal(certEntry, exp); diff != nil { + t.Fatal(diff) + } +} diff --git a/builtin/credential/cert/path_certs.go b/builtin/credential/cert/path_certs.go index dfbee218c6a5..a286e8b11e7a 100644 --- a/builtin/credential/cert/path_certs.go +++ b/builtin/credential/cert/path_certs.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/parseutil" "github.com/hashicorp/vault/sdk/helper/policyutil" + "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" ) @@ -28,7 +29,7 @@ func pathListCerts(b *backend) *framework.Path { } func pathCerts(b *backend) *framework.Path { - return &framework.Path{ + p := &framework.Path{ Pattern: "certs/" + framework.GenericNameRegex("name"), Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ @@ -95,39 +96,38 @@ certificate.`, "policies": &framework.FieldSchema{ Type: framework.TypeCommaStringSlice, - Description: "Comma-separated list of policies.", + Description: tokenutil.DeprecationText("token_policies"), + Deprecated: true, }, "lease": &framework.FieldSchema{ - Type: framework.TypeInt, - Description: `Deprecated: use "ttl" instead. TTL time in -seconds. Defaults to system/backend default TTL.`, + Type: framework.TypeInt, + Description: tokenutil.DeprecationText("token_ttl"), + Deprecated: true, }, "ttl": &framework.FieldSchema{ - Type: framework.TypeDurationSecond, - Description: `TTL for tokens issued by this backend. -Defaults to system/backend default TTL time.`, + Type: framework.TypeDurationSecond, + Description: tokenutil.DeprecationText("token_ttl"), + Deprecated: true, }, "max_ttl": &framework.FieldSchema{ - Type: framework.TypeDurationSecond, - Description: `Duration in either an integer number of seconds (3600) or -an integer time unit (60m) after which the -issued token can no longer be renewed.`, + Type: framework.TypeDurationSecond, + Description: tokenutil.DeprecationText("token_max_ttl"), + Deprecated: true, }, "period": &framework.FieldSchema{ - Type: framework.TypeDurationSecond, - Description: `If set, indicates that the token generated using this role -should never expire. The token should be renewed within the -duration specified by this value. At each renewal, the token's -TTL will be set to the value of this parameter.`, + Type: framework.TypeDurationSecond, + Description: tokenutil.DeprecationText("token_period"), + Deprecated: true, }, + "bound_cidrs": &framework.FieldSchema{ - Type: framework.TypeCommaStringSlice, - Description: `Comma separated string or list of CIDR blocks. If set, specifies the blocks of -IP addresses which can perform the login operation.`, + Type: framework.TypeCommaStringSlice, + Description: tokenutil.DeprecationText("token_bound_cidrs"), + Deprecated: true, }, }, @@ -140,6 +140,9 @@ IP addresses which can perform the login operation.`, HelpSynopsis: pathCertHelpSyn, HelpDescription: pathCertHelpDesc, } + + tokenutil.AddTokenFields(p.Fields) + return p } func (b *backend) Cert(ctx context.Context, s logical.Storage, n string) (*CertEntry, error) { @@ -155,6 +158,23 @@ func (b *backend) Cert(ctx context.Context, s logical.Storage, n string) (*CertE if err := entry.DecodeJSON(&result); err != nil { return nil, err } + + if result.TokenTTL == 0 && result.TTL > 0 { + result.TokenTTL = result.TTL + } + if result.TokenMaxTTL == 0 && result.MaxTTL > 0 { + result.TokenMaxTTL = result.MaxTTL + } + if result.TokenPeriod == 0 && result.Period > 0 { + result.TokenPeriod = result.Period + } + if len(result.TokenPolicies) == 0 && len(result.Policies) > 0 { + result.TokenPolicies = result.Policies + } + if len(result.TokenBoundCIDRs) == 0 && len(result.BoundCIDRs) > 0 { + result.TokenBoundCIDRs = result.BoundCIDRs + } + return &result, nil } @@ -183,86 +203,202 @@ func (b *backend) pathCertRead(ctx context.Context, req *logical.Request, d *fra return nil, nil } + data := map[string]interface{}{ + "certificate": cert.Certificate, + "display_name": cert.DisplayName, + "allowed_names": cert.AllowedNames, + "allowed_common_names": cert.AllowedCommonNames, + "allowed_dns_sans": cert.AllowedDNSSANs, + "allowed_email_sans": cert.AllowedEmailSANs, + "allowed_uri_sans": cert.AllowedURISANs, + "allowed_organizational_units": cert.AllowedOrganizationalUnits, + "required_extensions": cert.RequiredExtensions, + } + cert.PopulateTokenData(data) + + if cert.TTL > 0 { + data["ttl"] = int64(cert.TTL.Seconds()) + } + if cert.MaxTTL > 0 { + data["max_ttl"] = int64(cert.MaxTTL.Seconds()) + } + if cert.Period > 0 { + data["period"] = int64(cert.Period.Seconds()) + } + if len(cert.Policies) > 0 { + data["policies"] = data["token_policies"] + } + if len(cert.BoundCIDRs) > 0 { + data["bound_cidrs"] = data["token_bound_cidrs"] + } + return &logical.Response{ - Data: map[string]interface{}{ - "certificate": cert.Certificate, - "display_name": cert.DisplayName, - "policies": cert.Policies, - "ttl": cert.TTL / time.Second, - "max_ttl": cert.MaxTTL / time.Second, - "period": cert.Period / time.Second, - "allowed_names": cert.AllowedNames, - "allowed_common_names": cert.AllowedCommonNames, - "allowed_dns_sans": cert.AllowedDNSSANs, - "allowed_email_sans": cert.AllowedEmailSANs, - "allowed_uri_sans": cert.AllowedURISANs, - "allowed_organizational_units": cert.AllowedOrganizationalUnits, - "required_extensions": cert.RequiredExtensions, - "bound_cidrs": cert.BoundCIDRs, - }, + Data: data, }, nil } func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := strings.ToLower(d.Get("name").(string)) - certificate := d.Get("certificate").(string) - displayName := d.Get("display_name").(string) - policies := policyutil.ParsePolicies(d.Get("policies")) - allowedNames := d.Get("allowed_names").([]string) - allowedCommonNames := d.Get("allowed_common_names").([]string) - allowedDNSSANs := d.Get("allowed_dns_sans").([]string) - allowedEmailSANs := d.Get("allowed_email_sans").([]string) - allowedURISANs := d.Get("allowed_uri_sans").([]string) - allowedOrganizationalUnits := d.Get("allowed_organizational_units").([]string) - requiredExtensions := d.Get("required_extensions").([]string) - var resp logical.Response - - // Parse the ttl (or lease duration) - systemDefaultTTL := b.System().DefaultLeaseTTL() - ttl := time.Duration(d.Get("ttl").(int)) * time.Second - if ttl == 0 { - ttl = time.Duration(d.Get("lease").(int)) * time.Second - } - if ttl > systemDefaultTTL { - resp.AddWarning(fmt.Sprintf("Given ttl of %d seconds is greater than current mount/system default of %d seconds", ttl/time.Second, systemDefaultTTL/time.Second)) + cert, err := b.Cert(ctx, req.Storage, name) + if err != nil { + return nil, err } - if ttl < time.Duration(0) { - return logical.ErrorResponse("ttl cannot be negative"), nil + if cert == nil { + cert = &CertEntry{ + Name: name, + } } - // Parse max_ttl - systemMaxTTL := b.System().MaxLeaseTTL() - maxTTL := time.Duration(d.Get("max_ttl").(int)) * time.Second - if maxTTL > systemMaxTTL { - resp.AddWarning(fmt.Sprintf("Given max_ttl of %d seconds is greater than current mount/system default of %d seconds", maxTTL/time.Second, systemMaxTTL/time.Second)) + // Get non tokenutil fields + if certificateRaw, ok := d.GetOk("certificate"); ok { + cert.Certificate = certificateRaw.(string) } - - if maxTTL < time.Duration(0) { - return logical.ErrorResponse("max_ttl cannot be negative"), nil + if displayNameRaw, ok := d.GetOk("display_name"); ok { + cert.DisplayName = displayNameRaw.(string) + } + if allowedNamesRaw, ok := d.GetOk("allowed_names"); ok { + cert.AllowedNames = allowedNamesRaw.([]string) + } + if allowedCommonNamesRaw, ok := d.GetOk("allowed_common_names"); ok { + cert.AllowedCommonNames = allowedCommonNamesRaw.([]string) + } + if allowedDNSSANsRaw, ok := d.GetOk("allowed_dns_sans"); ok { + cert.AllowedDNSSANs = allowedDNSSANsRaw.([]string) + } + if allowedEmailSANsRaw, ok := d.GetOk("allowed_email_sans"); ok { + cert.AllowedEmailSANs = allowedEmailSANsRaw.([]string) + } + if allowedURISANsRaw, ok := d.GetOk("allowed_uri_sans"); ok { + cert.AllowedURISANs = allowedURISANsRaw.([]string) + } + if allowedOrganizationalUnitsRaw, ok := d.GetOk("allowed_organizational_units"); ok { + cert.AllowedOrganizationalUnits = allowedOrganizationalUnitsRaw.([]string) + } + if requiredExtensionsRaw, ok := d.GetOk("required_extensions"); ok { + cert.RequiredExtensions = requiredExtensionsRaw.([]string) } - if maxTTL != 0 && ttl > maxTTL { - return logical.ErrorResponse("ttl should be shorter than max_ttl"), nil + // Get tokenutil fields + if err := cert.ParseTokenFields(req, d); err != nil { + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest } - // Parse period - period := time.Duration(d.Get("period").(int)) * time.Second - if period > systemMaxTTL { - resp.AddWarning(fmt.Sprintf("Given period of %d seconds is greater than the backend's maximum TTL of %d seconds", period/time.Second, systemMaxTTL/time.Second)) + // Handle upgrade cases + { + policiesRaw, ok := d.GetOk("token_policies") + if !ok { + policiesRaw, ok = d.GetOk("policies") + if ok { + cert.Policies = policyutil.ParsePolicies(policiesRaw) + cert.TokenPolicies = cert.Policies + } + } else { + _, ok = d.GetOk("policies") + if ok { + cert.Policies = cert.TokenPolicies + } else { + cert.Policies = nil + } + } + + ttlRaw, ok := d.GetOk("token_ttl") + if !ok { + ttlRaw, ok = d.GetOk("ttl") + if !ok { + ttlRaw, ok = d.GetOk("lease") + } + if ok { + cert.TTL = time.Duration(ttlRaw.(int)) * time.Second + cert.TokenTTL = cert.TTL + } + } else { + _, ok = d.GetOk("ttl") + if ok { + cert.TTL = cert.TokenTTL + } else { + cert.TTL = 0 + } + } + + maxTTLRaw, ok := d.GetOk("token_max_ttl") + if !ok { + maxTTLRaw, ok = d.GetOk("max_ttl") + if ok { + cert.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second + cert.TokenMaxTTL = cert.MaxTTL + } + } else { + _, ok = d.GetOk("max_ttl") + if ok { + cert.MaxTTL = cert.TokenMaxTTL + } else { + cert.MaxTTL = 0 + } + } + + periodRaw, ok := d.GetOk("token_period") + if !ok { + periodRaw, ok = d.GetOk("period") + if ok { + cert.Period = time.Duration(periodRaw.(int)) * time.Second + cert.TokenPeriod = cert.Period + } + } else { + _, ok = d.GetOk("period") + if ok { + cert.Period = cert.TokenPeriod + } else { + cert.Period = 0 + } + } + + boundCIDRsRaw, ok := d.GetOk("token_bound_cidrs") + if !ok { + boundCIDRsRaw, ok = d.GetOk("bound_cidrs") + if ok { + boundCIDRs, err := parseutil.ParseAddrs(boundCIDRsRaw) + if err != nil { + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + } + cert.BoundCIDRs = boundCIDRs + cert.TokenBoundCIDRs = cert.BoundCIDRs + } + } else { + _, ok = d.GetOk("bound_cidrs") + if ok { + cert.BoundCIDRs = cert.TokenBoundCIDRs + } else { + cert.BoundCIDRs = nil + } + } + } - if period < time.Duration(0) { - return logical.ErrorResponse("period cannot be negative"), nil + var resp logical.Response + + systemDefaultTTL := b.System().DefaultLeaseTTL() + if cert.TokenTTL > systemDefaultTTL { + resp.AddWarning(fmt.Sprintf("Given ttl of %d seconds is greater than current mount/system default of %d seconds", cert.TokenTTL/time.Second, systemDefaultTTL/time.Second)) + } + systemMaxTTL := b.System().MaxLeaseTTL() + if cert.TokenMaxTTL > systemMaxTTL { + resp.AddWarning(fmt.Sprintf("Given max_ttl of %d seconds is greater than current mount/system default of %d seconds", cert.TokenMaxTTL/time.Second, systemMaxTTL/time.Second)) + } + if cert.TokenMaxTTL != 0 && cert.TokenTTL > cert.TokenMaxTTL { + return logical.ErrorResponse("ttl should be shorter than max_ttl"), nil + } + if cert.TokenPeriod > systemMaxTTL { + resp.AddWarning(fmt.Sprintf("Given period of %d seconds is greater than the backend's maximum TTL of %d seconds", cert.TokenPeriod/time.Second, systemMaxTTL/time.Second)) } // Default the display name to the certificate name if not given - if displayName == "" { - displayName = name + if cert.DisplayName == "" { + cert.DisplayName = name } - parsed := parsePEM([]byte(certificate)) + parsed := parsePEM([]byte(cert.Certificate)) if len(parsed) == 0 { return logical.ErrorResponse("failed to parse certificate"), nil } @@ -281,31 +417,8 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr } } - parsedCIDRs, err := parseutil.ParseAddrs(d.Get("bound_cidrs")) - if err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest - } - - certEntry := &CertEntry{ - Name: name, - Certificate: certificate, - DisplayName: displayName, - Policies: policies, - AllowedNames: allowedNames, - AllowedCommonNames: allowedCommonNames, - AllowedDNSSANs: allowedDNSSANs, - AllowedEmailSANs: allowedEmailSANs, - AllowedURISANs: allowedURISANs, - AllowedOrganizationalUnits: allowedOrganizationalUnits, - RequiredExtensions: requiredExtensions, - TTL: ttl, - MaxTTL: maxTTL, - Period: period, - BoundCIDRs: parsedCIDRs, - } - // Store it - entry, err := logical.StorageEntryJSON("cert/"+name, certEntry) + entry, err := logical.StorageEntryJSON("cert/"+name, cert) if err != nil { return nil, err } @@ -321,6 +434,8 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr } type CertEntry struct { + tokenutil.TokenParams + Name string Certificate string DisplayName string diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index c5cf0eaef388..de5db78d66c3 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -83,36 +83,28 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId) akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId) - resp := &logical.Response{ - Auth: &logical.Auth{ - Period: matched.Entry.Period, - InternalData: map[string]interface{}{ - "subject_key_id": skid, - "authority_key_id": akid, - }, - Policies: matched.Entry.Policies, - DisplayName: matched.Entry.DisplayName, - Metadata: map[string]string{ - "cert_name": matched.Entry.Name, - "common_name": clientCerts[0].Subject.CommonName, - "serial_number": clientCerts[0].SerialNumber.String(), - "subject_key_id": certutil.GetHexFormatted(clientCerts[0].SubjectKeyId, ":"), - "authority_key_id": certutil.GetHexFormatted(clientCerts[0].AuthorityKeyId, ":"), - }, - LeaseOptions: logical.LeaseOptions{ - Renewable: true, - TTL: matched.Entry.TTL, - MaxTTL: matched.Entry.MaxTTL, - }, - Alias: &logical.Alias{ - Name: clientCerts[0].Subject.CommonName, - }, - BoundCIDRs: matched.Entry.BoundCIDRs, + auth := &logical.Auth{ + InternalData: map[string]interface{}{ + "subject_key_id": skid, + "authority_key_id": akid, + }, + DisplayName: matched.Entry.DisplayName, + Metadata: map[string]string{ + "cert_name": matched.Entry.Name, + "common_name": clientCerts[0].Subject.CommonName, + "serial_number": clientCerts[0].SerialNumber.String(), + "subject_key_id": certutil.GetHexFormatted(clientCerts[0].SubjectKeyId, ":"), + "authority_key_id": certutil.GetHexFormatted(clientCerts[0].AuthorityKeyId, ":"), + }, + Alias: &logical.Alias{ + Name: clientCerts[0].Subject.CommonName, }, } + matched.Entry.PopulateTokenAuth(auth) - // Generate a response - return resp, nil + return &logical.Response{ + Auth: auth, + }, nil } func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { @@ -159,14 +151,14 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f return nil, nil } - if !policyutil.EquivalentPolicies(cert.Policies, req.Auth.TokenPolicies) { + if !policyutil.EquivalentPolicies(cert.TokenPolicies, req.Auth.TokenPolicies) { return nil, fmt.Errorf("policies have changed, not renewing") } resp := &logical.Response{Auth: req.Auth} - resp.Auth.TTL = cert.TTL - resp.Auth.MaxTTL = cert.MaxTTL - resp.Auth.Period = cert.Period + resp.Auth.TTL = cert.TokenTTL + resp.Auth.MaxTTL = cert.TokenMaxTTL + resp.Auth.Period = cert.TokenPeriod return resp, nil } @@ -478,7 +470,7 @@ func (b *backend) checkForValidChain(chains [][]*x509.Certificate) bool { } func (b *backend) checkCIDR(cert *CertEntry, req *logical.Request) error { - if cidrutil.RemoteAddrIsOk(req.Connection.RemoteAddr, cert.BoundCIDRs) { + if cidrutil.RemoteAddrIsOk(req.Connection.RemoteAddr, cert.TokenBoundCIDRs) { return nil } return logical.ErrPermissionDenied