Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update github to tokenutil #7031

Merged
merged 3 commits into from
Jul 1, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 23 additions & 18 deletions builtin/credential/github/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,34 +35,35 @@ func TestBackend_Config(t *testing.T) {
"ttl": "",
"max_ttl": "",
}
expectedTTL1, _ := time.ParseDuration("24h0m0s")
expectedTTL1 := 24 * time.Hour
config_data2 := map[string]interface{}{
"organization": os.Getenv("GITHUB_ORG"),
"ttl": "1h",
"max_ttl": "2h",
}
expectedTTL2, _ := time.ParseDuration("1h0m0s")
expectedTTL2 := time.Hour
config_data3 := map[string]interface{}{
"organization": os.Getenv("GITHUB_ORG"),
"ttl": "50h",
"max_ttl": "50h",
}
expectedTTL3 := 48 * time.Hour

logicaltest.Test(t, logicaltest.TestCase{
PreCheck: func() { testAccPreCheck(t) },
LogicalBackend: b,
PreCheck: func() { testAccPreCheck(t) },
CredentialBackend: b,
Steps: []logicaltest.TestStep{
testConfigWrite(t, config_data1),
testLoginWrite(t, login_data, expectedTTL1.Nanoseconds(), false),
testLoginWrite(t, login_data, expectedTTL1, false),
testConfigWrite(t, config_data2),
testLoginWrite(t, login_data, expectedTTL2.Nanoseconds(), false),
testLoginWrite(t, login_data, expectedTTL2, false),
testConfigWrite(t, config_data3),
testLoginWrite(t, login_data, 0, true),
testLoginWrite(t, login_data, expectedTTL3, true),
},
})
}

func testLoginWrite(t *testing.T, d map[string]interface{}, expectedTTL int64, expectFail bool) logicaltest.TestStep {
func testLoginWrite(t *testing.T, d map[string]interface{}, expectedTTL time.Duration, expectFail bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "login",
Expand All @@ -72,10 +73,9 @@ func testLoginWrite(t *testing.T, d map[string]interface{}, expectedTTL int64, e
if resp.IsError() && expectFail {
return nil
}
var actualTTL int64
actualTTL = resp.Auth.LeaseOptions.TTL.Nanoseconds()
actualTTL := resp.Auth.LeaseOptions.TTL
if actualTTL != expectedTTL {
return fmt.Errorf("TTL mismatched. Expected: %d Actual: %d", expectedTTL, resp.Auth.LeaseOptions.TTL.Nanoseconds())
return fmt.Errorf("TTL mismatched. Expected: %d Actual: %d", expectedTTL, resp.Auth.LeaseOptions.TTL)
}
return nil
},
Expand Down Expand Up @@ -105,25 +105,25 @@ func TestBackend_basic(t *testing.T) {
}

logicaltest.Test(t, logicaltest.TestCase{
PreCheck: func() { testAccPreCheck(t) },
LogicalBackend: b,
PreCheck: func() { testAccPreCheck(t) },
CredentialBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, false),
testAccMap(t, "default", "fakepol"),
testAccMap(t, "oWnErs", "fakepol"),
testAccLogin(t, []string{"default", "fakepol"}),
testAccLogin(t, []string{"default", "abc", "fakepol"}),
testAccStepConfig(t, true),
testAccMap(t, "default", "fakepol"),
testAccMap(t, "oWnErs", "fakepol"),
testAccLogin(t, []string{"default", "fakepol"}),
testAccLogin(t, []string{"default", "abc", "fakepol"}),
testAccStepConfigWithBaseURL(t),
testAccMap(t, "default", "fakepol"),
testAccMap(t, "oWnErs", "fakepol"),
testAccLogin(t, []string{"default", "fakepol"}),
testAccLogin(t, []string{"default", "abc", "fakepol"}),
testAccMap(t, "default", "fakepol"),
testAccStepConfig(t, true),
mapUserToPolicy(t, os.Getenv("GITHUB_USER"), "userpolicy"),
testAccLogin(t, []string{"default", "fakepol", "userpolicy"}),
testAccLogin(t, []string{"default", "abc", "fakepol", "userpolicy"}),
},
})
}
Expand All @@ -133,6 +133,10 @@ func testAccPreCheck(t *testing.T) {
t.Skip("GITHUB_TOKEN must be set for acceptance tests")
}

if v := os.Getenv("GITHUB_USER"); v == "" {
t.Skip("GITHUB_USER must be set for acceptance tests")
}

if v := os.Getenv("GITHUB_ORG"); v == "" {
t.Skip("GITHUB_ORG must be set for acceptance tests")
}
Expand All @@ -147,7 +151,8 @@ func testAccStepConfig(t *testing.T, upper bool) logicaltest.TestStep {
Operation: logical.UpdateOperation,
Path: "config",
Data: map[string]interface{}{
"organization": os.Getenv("GITHUB_ORG"),
"organization": os.Getenv("GITHUB_ORG"),
"token_policies": []string{"abc"},
},
}
if upper {
Expand Down
143 changes: 93 additions & 50 deletions builtin/credential/github/path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@ import (
"context"
"fmt"
"net/url"
"strings"
"time"

"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/tokenutil"
"github.com/hashicorp/vault/sdk/logical"
)

func pathConfig(b *backend) *framework.Path {
return &framework.Path{
p := &framework.Path{
Pattern: "config",
Fields: map[string]*framework.FieldSchema{
"organization": &framework.FieldSchema{
Expand All @@ -31,18 +33,14 @@ API-compatible authentication server.`,
},
},
"ttl": &framework.FieldSchema{
Type: framework.TypeString,
Description: `Duration after which authentication will be expired`,
DisplayAttrs: &framework.DisplayAttributes{
Name: "TTL",
},
Type: framework.TypeDurationSecond,
Description: tokenutil.DeprecationText("token_ttl"),
Deprecated: true,
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeString,
Description: `Maximum duration after which authentication will be expired`,
DisplayAttrs: &framework.DisplayAttributes{
Name: "Max TTL",
},
Type: framework.TypeDurationSecond,
Description: tokenutil.DeprecationText("token_max_ttl"),
Deprecated: true,
},
},

Expand All @@ -51,48 +49,77 @@ API-compatible authentication server.`,
logical.ReadOperation: b.pathConfigRead,
},
}

tokenutil.AddTokenFields(p.Fields)
p.Fields["token_policies"].Description += ". This will apply to all tokens generated by this auth method, in addition to any configured for specific users/groups."
jefferai marked this conversation as resolved.
Show resolved Hide resolved
return p
}

func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
organization := data.Get("organization").(string)
baseURL := data.Get("base_url").(string)
if len(baseURL) != 0 {
c, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
if c == nil {
c = &config{}
}

if organizationRaw, ok := data.GetOk("organization"); ok {
c.Organization = organizationRaw.(string)
}

if baseURLRaw, ok := data.GetOk("base_url"); ok {
baseURL := baseURLRaw.(string)
_, err := url.Parse(baseURL)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error parsing given base_url: %s", err)), nil
}
if !strings.HasSuffix(baseURL, "/") {
baseURL += "/"
}
c.BaseURL = baseURL
}

var ttl time.Duration
var err error
ttlRaw, ok := data.GetOk("ttl")
if !ok || len(ttlRaw.(string)) == 0 {
ttl = 0
} else {
ttl, err = time.ParseDuration(ttlRaw.(string))
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Invalid 'ttl':%s", err)), nil
}
if err := c.ParseTokenFields(req, data); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}

var maxTTL time.Duration
maxTTLRaw, ok := data.GetOk("max_ttl")
if !ok || len(maxTTLRaw.(string)) == 0 {
maxTTL = 0
} else {
maxTTL, err = time.ParseDuration(maxTTLRaw.(string))
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Invalid 'max_ttl':%s", err)), nil
// Handle upgrade cases
{
ttlRaw, ok := data.GetOk("token_ttl")
if !ok {
ttlRaw, ok = data.GetOk("ttl")
if ok {
c.TTL = time.Duration(ttlRaw.(int)) * time.Second
c.TokenTTL = c.TTL
}
} else {
_, ok = data.GetOk("ttl")
if ok {
c.TTL = c.TokenTTL
} else {
c.TTL = 0
}
}
}

entry, err := logical.StorageEntryJSON("config", config{
Organization: organization,
BaseURL: baseURL,
TTL: ttl,
MaxTTL: maxTTL,
})
maxTTLRaw, ok := data.GetOk("token_max_ttl")
if !ok {
maxTTLRaw, ok = data.GetOk("max_ttl")
if ok {
c.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second
c.TokenMaxTTL = c.MaxTTL
}
} else {
_, ok = data.GetOk("max_ttl")
if ok {
c.MaxTTL = c.TokenMaxTTL
} else {
c.MaxTTL = 0
}
}
}

entry, err := logical.StorageEntryJSON("config", c)
if err != nil {
return nil, err
}
Expand All @@ -109,23 +136,26 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data
if err != nil {
return nil, err
}

if config == nil {
return nil, fmt.Errorf("configuration object not found")
}

config.TTL /= time.Second
config.MaxTTL /= time.Second
d := map[string]interface{}{
"organization": config.Organization,
"base_url": config.BaseURL,
}
config.PopulateTokenData(d)

resp := &logical.Response{
Data: map[string]interface{}{
"organization": config.Organization,
"base_url": config.BaseURL,
"ttl": config.TTL,
"max_ttl": config.MaxTTL,
},
if config.TTL > 0 {
jefferai marked this conversation as resolved.
Show resolved Hide resolved
d["ttl"] = int64(config.TTL.Seconds())
}
return resp, nil
if config.MaxTTL > 0 {
d["max_ttl"] = int64(config.MaxTTL.Seconds())
}

return &logical.Response{
Data: d,
}, nil
}

// Config returns the configuration for this backend.
Expand All @@ -135,17 +165,30 @@ func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error
return nil, err
}

if entry == nil {
return nil, nil
}

var result config
if entry != nil {
if err := entry.DecodeJSON(&result); err != nil {
return nil, errwrap.Wrapf("error reading configuration: {{err}}", err)
}
}

if result.TokenTTL == 0 && result.TTL > 0 {
result.TokenTTL = result.TTL
}
if result.TokenMaxTTL == 0 && result.MaxTTL > 0 {
result.TokenMaxTTL = result.MaxTTL
}

return &result, nil
}

type config struct {
tokenutil.TokenParams

Organization string `json:"organization" structs:"organization" mapstructure:"organization"`
BaseURL string `json:"base_url" structs:"base_url" mapstructure:"base_url"`
TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"`
Expand Down
Loading