Skip to content

Commit

Permalink
Tokenutilize Okta (#7032)
Browse files Browse the repository at this point in the history
  • Loading branch information
jefferai authored Jul 1, 2019
1 parent 121e3ce commit cafee24
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 47 deletions.
101 changes: 72 additions & 29 deletions builtin/credential/okta/path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/chrismalek/oktasdk-go/okta"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/tokenutil"
"github.com/hashicorp/vault/sdk/logical"
)

Expand All @@ -19,7 +20,7 @@ const (
)

func pathConfig(b *backend) *framework.Path {
return &framework.Path{
p := &framework.Path{
Pattern: `config`,
Fields: map[string]*framework.FieldSchema{
"organization": &framework.FieldSchema{
Expand Down Expand Up @@ -60,17 +61,13 @@ func pathConfig(b *backend) *framework.Path {
},
"ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: `Duration after which authentication will be expired`,
DisplayAttrs: &framework.DisplayAttributes{
Name: "TTL",
},
Description: tokenutil.DeprecationText("token_ttl"),
Deprecated: true,
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: `Maximum duration after which authentication will be expired`,
DisplayAttrs: &framework.DisplayAttributes{
Name: "Max TTL",
},
Description: tokenutil.DeprecationText("token_max_ttl"),
Deprecated: true,
},
"bypass_okta_mfa": &framework.FieldSchema{
Type: framework.TypeBool,
Expand All @@ -91,6 +88,10 @@ func pathConfig(b *backend) *framework.Path {

HelpSynopsis: pathConfigHelp,
}

tokenutil.AddTokenFields(p.Fields)
p.Fields["token_policies"].Description += ". This will apply to all tokens generated by this auth method, in addition to any configured for specific users/groups."
return p
}

// Config returns the configuration for this backend.
Expand All @@ -110,6 +111,13 @@ func (b *backend) Config(ctx context.Context, s logical.Storage) (*ConfigEntry,
}
}

if result.TokenTTL == 0 && result.TTL > 0 {
result.TokenTTL = result.TTL
}
if result.TokenMaxTTL == 0 && result.MaxTTL > 0 {
result.TokenMaxTTL = result.MaxTTL
}

return &result, nil
}

Expand All @@ -122,20 +130,28 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *f
return nil, nil
}

resp := &logical.Response{
Data: map[string]interface{}{
"organization": cfg.Org,
"org_name": cfg.Org,
"ttl": cfg.TTL.Seconds(),
"max_ttl": cfg.MaxTTL.Seconds(),
"bypass_okta_mfa": cfg.BypassOktaMFA,
},
data := map[string]interface{}{
"organization": cfg.Org,
"org_name": cfg.Org,
"bypass_okta_mfa": cfg.BypassOktaMFA,
}
cfg.PopulateTokenData(data)

if cfg.BaseURL != "" {
resp.Data["base_url"] = cfg.BaseURL
data["base_url"] = cfg.BaseURL
}
if cfg.Production != nil {
resp.Data["production"] = *cfg.Production
data["production"] = *cfg.Production
}
if cfg.TTL > 0 {
data["ttl"] = int64(cfg.TTL.Seconds())
}
if cfg.MaxTTL > 0 {
data["max_ttl"] = int64(cfg.MaxTTL.Seconds())
}

resp := &logical.Response{
Data: data,
}

if cfg.BypassOktaMFA {
Expand Down Expand Up @@ -206,18 +222,43 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
cfg.BypassOktaMFA = bypass.(bool)
}

ttl, ok := d.GetOk("ttl")
if ok {
cfg.TTL = time.Duration(ttl.(int)) * time.Second
} else if req.Operation == logical.CreateOperation {
cfg.TTL = time.Duration(d.Get("ttl").(int)) * time.Second
if err := cfg.ParseTokenFields(req, d); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}

maxTTL, ok := d.GetOk("max_ttl")
if ok {
cfg.MaxTTL = time.Duration(maxTTL.(int)) * time.Second
} else if req.Operation == logical.CreateOperation {
cfg.MaxTTL = time.Duration(d.Get("max_ttl").(int)) * time.Second
// Handle upgrade cases
{
ttlRaw, ok := d.GetOk("token_ttl")
if !ok {
ttlRaw, ok = d.GetOk("ttl")
if ok {
cfg.TTL = time.Duration(ttlRaw.(int)) * time.Second
cfg.TokenTTL = cfg.TTL
}
} else {
_, ok = d.GetOk("ttl")
if ok {
cfg.TTL = cfg.TokenTTL
} else {
cfg.TTL = 0
}
}

maxTTLRaw, ok := d.GetOk("token_max_ttl")
if !ok {
maxTTLRaw, ok = d.GetOk("max_ttl")
if ok {
cfg.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second
cfg.TokenMaxTTL = cfg.MaxTTL
}
} else {
_, ok = d.GetOk("max_ttl")
if ok {
cfg.MaxTTL = cfg.TokenMaxTTL
} else {
cfg.MaxTTL = 0
}
}
}

jsonCfg, err := logical.StorageEntryJSON("config", cfg)
Expand Down Expand Up @@ -265,6 +306,8 @@ func (c *ConfigEntry) OktaClient() *okta.Client {

// ConfigEntry for Okta
type ConfigEntry struct {
tokenutil.TokenParams

Org string `json:"organization"`
Token string `json:"token"`
BaseURL string `json:"base_url"`
Expand Down
39 changes: 21 additions & 18 deletions builtin/credential/okta/path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package okta
import (
"context"
"fmt"
"sort"
"strings"

"github.com/go-errors/errors"
Expand Down Expand Up @@ -70,15 +69,12 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
resp = &logical.Response{}
}

sort.Strings(policies)

cfg, err := b.getConfig(ctx, req)
if err != nil {
return nil, err
}

resp.Auth = &logical.Auth{
Policies: policies,
auth := &logical.Auth{
Metadata: map[string]string{
"username": username,
"policies": strings.Join(policies, ","),
Expand All @@ -87,15 +83,18 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
"password": password,
},
DisplayName: username,
LeaseOptions: logical.LeaseOptions{
TTL: cfg.TTL,
MaxTTL: cfg.MaxTTL,
Renewable: true,
},
Alias: &logical.Alias{
Name: username,
},
}
cfg.PopulateTokenAuth(auth)

// Add in configured policies from mappings
if len(policies) > 0 {
auth.Policies = append(auth.Policies, policies...)
}

resp.Auth = auth

for _, groupName := range groupNames {
if groupName == "" {
Expand All @@ -113,23 +112,27 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
username := req.Auth.Metadata["username"]
password := req.Auth.InternalData["password"].(string)

cfg, err := b.getConfig(ctx, req)
if err != nil {
return nil, err
}

loginPolicies, resp, groupNames, err := b.Login(ctx, req, username, password)
if len(loginPolicies) == 0 {
return resp, err
}

if !policyutil.EquivalentPolicies(loginPolicies, req.Auth.TokenPolicies) {
return nil, fmt.Errorf("policies have changed, not renewing")
finalPolicies := cfg.TokenPolicies
if len(loginPolicies) > 0 {
finalPolicies = append(finalPolicies, loginPolicies...)
}

cfg, err := b.getConfig(ctx, req)
if err != nil {
return nil, err
if !policyutil.EquivalentPolicies(finalPolicies, req.Auth.TokenPolicies) {
return nil, fmt.Errorf("policies have changed, not renewing")
}

resp.Auth = req.Auth
resp.Auth.TTL = cfg.TTL
resp.Auth.MaxTTL = cfg.MaxTTL
resp.Auth.TTL = cfg.TokenTTL
resp.Auth.MaxTTL = cfg.TokenMaxTTL

// Remove old aliases
resp.Auth.GroupAliases = nil
Expand Down

0 comments on commit cafee24

Please sign in to comment.