diff --git a/builtin/credential/github/backend_test.go b/builtin/credential/github/backend_test.go index 35036fd75aff..4e51eedba855 100644 --- a/builtin/credential/github/backend_test.go +++ b/builtin/credential/github/backend_test.go @@ -35,34 +35,35 @@ func TestBackend_Config(t *testing.T) { "ttl": "", "max_ttl": "", } - expectedTTL1, _ := time.ParseDuration("24h0m0s") + expectedTTL1 := 24 * time.Hour config_data2 := map[string]interface{}{ "organization": os.Getenv("GITHUB_ORG"), "ttl": "1h", "max_ttl": "2h", } - expectedTTL2, _ := time.ParseDuration("1h0m0s") + expectedTTL2 := time.Hour config_data3 := map[string]interface{}{ "organization": os.Getenv("GITHUB_ORG"), "ttl": "50h", "max_ttl": "50h", } + expectedTTL3 := 48 * time.Hour logicaltest.Test(t, logicaltest.TestCase{ - PreCheck: func() { testAccPreCheck(t) }, - LogicalBackend: b, + PreCheck: func() { testAccPreCheck(t) }, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testConfigWrite(t, config_data1), - testLoginWrite(t, login_data, expectedTTL1.Nanoseconds(), false), + testLoginWrite(t, login_data, expectedTTL1, false), testConfigWrite(t, config_data2), - testLoginWrite(t, login_data, expectedTTL2.Nanoseconds(), false), + testLoginWrite(t, login_data, expectedTTL2, false), testConfigWrite(t, config_data3), - testLoginWrite(t, login_data, 0, true), + testLoginWrite(t, login_data, expectedTTL3, true), }, }) } -func testLoginWrite(t *testing.T, d map[string]interface{}, expectedTTL int64, expectFail bool) logicaltest.TestStep { +func testLoginWrite(t *testing.T, d map[string]interface{}, expectedTTL time.Duration, expectFail bool) logicaltest.TestStep { return logicaltest.TestStep{ Operation: logical.UpdateOperation, Path: "login", @@ -72,10 +73,9 @@ func testLoginWrite(t *testing.T, d map[string]interface{}, expectedTTL int64, e if resp.IsError() && expectFail { return nil } - var actualTTL int64 - actualTTL = resp.Auth.LeaseOptions.TTL.Nanoseconds() + actualTTL := resp.Auth.LeaseOptions.TTL if actualTTL != expectedTTL { - return fmt.Errorf("TTL mismatched. Expected: %d Actual: %d", expectedTTL, resp.Auth.LeaseOptions.TTL.Nanoseconds()) + return fmt.Errorf("TTL mismatched. Expected: %d Actual: %d", expectedTTL, resp.Auth.LeaseOptions.TTL) } return nil }, @@ -105,25 +105,25 @@ func TestBackend_basic(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - PreCheck: func() { testAccPreCheck(t) }, - LogicalBackend: b, + PreCheck: func() { testAccPreCheck(t) }, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, false), testAccMap(t, "default", "fakepol"), testAccMap(t, "oWnErs", "fakepol"), - testAccLogin(t, []string{"default", "fakepol"}), + testAccLogin(t, []string{"default", "abc", "fakepol"}), testAccStepConfig(t, true), testAccMap(t, "default", "fakepol"), testAccMap(t, "oWnErs", "fakepol"), - testAccLogin(t, []string{"default", "fakepol"}), + testAccLogin(t, []string{"default", "abc", "fakepol"}), testAccStepConfigWithBaseURL(t), testAccMap(t, "default", "fakepol"), testAccMap(t, "oWnErs", "fakepol"), - testAccLogin(t, []string{"default", "fakepol"}), + testAccLogin(t, []string{"default", "abc", "fakepol"}), testAccMap(t, "default", "fakepol"), testAccStepConfig(t, true), mapUserToPolicy(t, os.Getenv("GITHUB_USER"), "userpolicy"), - testAccLogin(t, []string{"default", "fakepol", "userpolicy"}), + testAccLogin(t, []string{"default", "abc", "fakepol", "userpolicy"}), }, }) } @@ -133,6 +133,10 @@ func testAccPreCheck(t *testing.T) { t.Skip("GITHUB_TOKEN must be set for acceptance tests") } + if v := os.Getenv("GITHUB_USER"); v == "" { + t.Skip("GITHUB_USER must be set for acceptance tests") + } + if v := os.Getenv("GITHUB_ORG"); v == "" { t.Skip("GITHUB_ORG must be set for acceptance tests") } @@ -147,7 +151,8 @@ func testAccStepConfig(t *testing.T, upper bool) logicaltest.TestStep { Operation: logical.UpdateOperation, Path: "config", Data: map[string]interface{}{ - "organization": os.Getenv("GITHUB_ORG"), + "organization": os.Getenv("GITHUB_ORG"), + "token_policies": []string{"abc"}, }, } if upper { diff --git a/builtin/credential/github/path_config.go b/builtin/credential/github/path_config.go index 21d89b6f7aaf..be793b79785b 100644 --- a/builtin/credential/github/path_config.go +++ b/builtin/credential/github/path_config.go @@ -4,15 +4,17 @@ import ( "context" "fmt" "net/url" + "strings" "time" "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" ) func pathConfig(b *backend) *framework.Path { - return &framework.Path{ + p := &framework.Path{ Pattern: "config", Fields: map[string]*framework.FieldSchema{ "organization": &framework.FieldSchema{ @@ -31,18 +33,14 @@ API-compatible authentication server.`, }, }, "ttl": &framework.FieldSchema{ - Type: framework.TypeString, - Description: `Duration after which authentication will be expired`, - DisplayAttrs: &framework.DisplayAttributes{ - Name: "TTL", - }, + Type: framework.TypeDurationSecond, + Description: tokenutil.DeprecationText("token_ttl"), + Deprecated: true, }, "max_ttl": &framework.FieldSchema{ - Type: framework.TypeString, - Description: `Maximum duration after which authentication will be expired`, - DisplayAttrs: &framework.DisplayAttributes{ - Name: "Max TTL", - }, + Type: framework.TypeDurationSecond, + Description: tokenutil.DeprecationText("token_max_ttl"), + Deprecated: true, }, }, @@ -51,48 +49,77 @@ API-compatible authentication server.`, logical.ReadOperation: b.pathConfigRead, }, } + + tokenutil.AddTokenFields(p.Fields) + p.Fields["token_policies"].Description += ". This will apply to all tokens generated by this auth method, in addition to any policies configured for specific users/groups." + return p } func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - organization := data.Get("organization").(string) - baseURL := data.Get("base_url").(string) - if len(baseURL) != 0 { + c, err := b.Config(ctx, req.Storage) + if err != nil { + return nil, err + } + if c == nil { + c = &config{} + } + + if organizationRaw, ok := data.GetOk("organization"); ok { + c.Organization = organizationRaw.(string) + } + + if baseURLRaw, ok := data.GetOk("base_url"); ok { + baseURL := baseURLRaw.(string) _, err := url.Parse(baseURL) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error parsing given base_url: %s", err)), nil } + if !strings.HasSuffix(baseURL, "/") { + baseURL += "/" + } + c.BaseURL = baseURL } - var ttl time.Duration - var err error - ttlRaw, ok := data.GetOk("ttl") - if !ok || len(ttlRaw.(string)) == 0 { - ttl = 0 - } else { - ttl, err = time.ParseDuration(ttlRaw.(string)) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Invalid 'ttl':%s", err)), nil - } + if err := c.ParseTokenFields(req, data); err != nil { + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest } - var maxTTL time.Duration - maxTTLRaw, ok := data.GetOk("max_ttl") - if !ok || len(maxTTLRaw.(string)) == 0 { - maxTTL = 0 - } else { - maxTTL, err = time.ParseDuration(maxTTLRaw.(string)) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Invalid 'max_ttl':%s", err)), nil + // Handle upgrade cases + { + ttlRaw, ok := data.GetOk("token_ttl") + if !ok { + ttlRaw, ok = data.GetOk("ttl") + if ok { + c.TTL = time.Duration(ttlRaw.(int)) * time.Second + c.TokenTTL = c.TTL + } + } else { + _, ok = data.GetOk("ttl") + if ok { + c.TTL = c.TokenTTL + } else { + c.TTL = 0 + } } - } - entry, err := logical.StorageEntryJSON("config", config{ - Organization: organization, - BaseURL: baseURL, - TTL: ttl, - MaxTTL: maxTTL, - }) + maxTTLRaw, ok := data.GetOk("token_max_ttl") + if !ok { + maxTTLRaw, ok = data.GetOk("max_ttl") + if ok { + c.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second + c.TokenMaxTTL = c.MaxTTL + } + } else { + _, ok = data.GetOk("max_ttl") + if ok { + c.MaxTTL = c.TokenMaxTTL + } else { + c.MaxTTL = 0 + } + } + } + entry, err := logical.StorageEntryJSON("config", c) if err != nil { return nil, err } @@ -109,23 +136,26 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data if err != nil { return nil, err } - if config == nil { return nil, fmt.Errorf("configuration object not found") } - config.TTL /= time.Second - config.MaxTTL /= time.Second + d := map[string]interface{}{ + "organization": config.Organization, + "base_url": config.BaseURL, + } + config.PopulateTokenData(d) - resp := &logical.Response{ - Data: map[string]interface{}{ - "organization": config.Organization, - "base_url": config.BaseURL, - "ttl": config.TTL, - "max_ttl": config.MaxTTL, - }, + if config.TTL > 0 { + d["ttl"] = int64(config.TTL.Seconds()) } - return resp, nil + if config.MaxTTL > 0 { + d["max_ttl"] = int64(config.MaxTTL.Seconds()) + } + + return &logical.Response{ + Data: d, + }, nil } // Config returns the configuration for this backend. @@ -135,6 +165,10 @@ func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error return nil, err } + if entry == nil { + return nil, nil + } + var result config if entry != nil { if err := entry.DecodeJSON(&result); err != nil { @@ -142,10 +176,19 @@ func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error } } + if result.TokenTTL == 0 && result.TTL > 0 { + result.TokenTTL = result.TTL + } + if result.TokenMaxTTL == 0 && result.MaxTTL > 0 { + result.TokenMaxTTL = result.MaxTTL + } + return &result, nil } type config struct { + tokenutil.TokenParams + Organization string `json:"organization" structs:"organization" mapstructure:"organization"` BaseURL string `json:"base_url" structs:"base_url" mapstructure:"base_url"` TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"` diff --git a/builtin/credential/github/path_login.go b/builtin/credential/github/path_login.go index 4bf5a27e297b..7856eeb9faba 100644 --- a/builtin/credential/github/path_login.go +++ b/builtin/credential/github/path_login.go @@ -63,31 +63,28 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra verifyResp = verifyResponse } - config, err := b.Config(ctx, req.Storage) - if err != nil { - return nil, err + auth := &logical.Auth{ + InternalData: map[string]interface{}{ + "token": token, + }, + Metadata: map[string]string{ + "username": *verifyResp.User.Login, + "org": *verifyResp.Org.Login, + }, + DisplayName: *verifyResp.User.Login, + Alias: &logical.Alias{ + Name: *verifyResp.User.Login, + }, + } + verifyResp.Config.PopulateTokenAuth(auth) + + // Add in configured policies from user/group mapping + if len(verifyResp.Policies) > 0 { + auth.Policies = append(auth.Policies, verifyResp.Policies...) } resp := &logical.Response{ - Auth: &logical.Auth{ - InternalData: map[string]interface{}{ - "token": token, - }, - Policies: verifyResp.Policies, - Metadata: map[string]string{ - "username": *verifyResp.User.Login, - "org": *verifyResp.Org.Login, - }, - DisplayName: *verifyResp.User.Login, - LeaseOptions: logical.LeaseOptions{ - TTL: config.TTL, - MaxTTL: config.MaxTTL, - Renewable: true, - }, - Alias: &logical.Alias{ - Name: *verifyResp.User.Login, - }, - }, + Auth: auth, } for _, teamName := range verifyResp.TeamNames { @@ -125,14 +122,9 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f return nil, fmt.Errorf("policies do not match") } - config, err := b.Config(ctx, req.Storage) - if err != nil { - return nil, err - } - resp := &logical.Response{Auth: req.Auth} - resp.Auth.TTL = config.TTL - resp.Auth.MaxTTL = config.MaxTTL + resp.Auth.TTL = verifyResp.Config.TokenTTL + resp.Auth.MaxTTL = verifyResp.Config.TokenMaxTTL // Remove old aliases resp.Auth.GroupAliases = nil @@ -151,9 +143,13 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, t if err != nil { return nil, nil, err } + if config == nil { + return nil, logical.ErrorResponse("configuration has not been set"), nil + } + if config.Organization == "" { return nil, logical.ErrorResponse( - "configure the github credential backend first"), nil + "organization not found in configuration"), nil } client, err := b.Client(token) @@ -255,6 +251,7 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, t Org: org, Policies: append(groupPoliciesList, userPoliciesList...), TeamNames: teamNames, + Config: config, }, nil, nil } @@ -263,4 +260,7 @@ type verifyCredentialsResp struct { Org *github.Organization Policies []string TeamNames []string + + // This is just a cache to send back to the caller + Config *config }