Skip to content

Commit

Permalink
reduce calls to DetermineRoleFromLoginRequest from 3 to 1 for aws aut…
Browse files Browse the repository at this point in the history
…h method (#22583)

* reduce calls to DetermineRoleFromLoginRequest from 3 to 1 for aws auth method

* change ordering of LoginCreateToken args

* replace another determineRoleFromLoginRequest function with role from context

* add changelog

* Check for role in context if not there make call to DeteremineRoleFromLoginRequest

* move context role check below nanmespace check

* Update changelog/22583.txt

Co-authored-by: Nick Cabatoff <[email protected]>

* revert signature to same order

* make sure resp is last argument

* retrieve role from context closer to where role variable is needed

* remove failsafe for role in mfa login

* Update changelog/22583.txt

---------

Co-authored-by: Nick Cabatoff <[email protected]>
  • Loading branch information
elliesterner and ncabatoff authored Aug 28, 2023
1 parent aa05ba6 commit cccfdb0
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 9 deletions.
3 changes: 3 additions & 0 deletions changelog/22583.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
core/quotas: Reduce overhead for role calculation when using cloud auth methods.
```
8 changes: 7 additions & 1 deletion http/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package http

import (
"bytes"
"context"
"errors"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -59,11 +60,16 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
}
r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))

role := core.DetermineRoleFromLoginRequestFromBytes(mountPath, bodyBytes, r.Context())

// add an entry to the context to prevent recalculating request role unnecessarily
r = r.WithContext(context.WithValue(r.Context(), logical.CtxKeyRequestRole{}, role))

quotaResp, err := core.ApplyRateLimitQuota(r.Context(), &quotas.Request{
Type: quotas.TypeRateLimit,
Path: path,
MountPath: mountPath,
Role: core.DetermineRoleFromLoginRequestFromBytes(mountPath, bodyBytes, r.Context()),
Role: role,
NamespacePath: ns.Path,
ClientAddress: parseRemoteIPAddress(r),
})
Expand Down
6 changes: 6 additions & 0 deletions sdk/logical/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,9 @@ type CtxKeyInFlightRequestID struct{}
func (c CtxKeyInFlightRequestID) String() string {
return "in-flight-request-ID"
}

type CtxKeyRequestRole struct{}

func (c CtxKeyRequestRole) String() string {
return "request-role"
}
9 changes: 7 additions & 2 deletions vault/login_mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -791,12 +791,17 @@ func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAu
return nil, fmt.Errorf("namespace not found: %w", err)
}

var role string
if reqRole := ctx.Value(logical.CtxKeyRequestRole{}); reqRole != nil {
role = reqRole.(string)
}

// The request successfully authenticated itself. Run the quota checks on
// the original login request path before creating the token.
quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, &quotas.Request{
Path: reqPath,
MountPath: strings.TrimPrefix(mountPoint, ns.Path),
Role: c.DetermineRoleFromLoginRequest(mountPoint, loginRequestData, ctx),
Role: role,
NamespacePath: ns.Path,
})

Expand All @@ -816,7 +821,7 @@ func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAu
// note that we don't need to handle the error for the following function right away.
// The function takes the response as in input variable and modify it. So, the returned
// arguments are resp and err.
leaseGenerated, resp, err := c.LoginCreateToken(ctx, ns, reqPath, mountPoint, resp, loginRequestData)
leaseGenerated, resp, err := c.LoginCreateToken(ctx, ns, reqPath, mountPoint, role, resp)

if quotaResp.Access != nil {
quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated)
Expand Down
28 changes: 22 additions & 6 deletions vault/request_handling.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,10 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R
if ok {
ctx = context.WithValue(ctx, logical.CtxKeyInFlightRequestID{}, inFlightReqID)
}
requestRole, ok := httpCtx.Value(logical.CtxKeyRequestRole{}).(string)
if ok {
ctx = context.WithValue(ctx, logical.CtxKeyRequestRole{}, requestRole)
}
resp, err = c.handleCancelableRequest(ctx, req)
req.SetTokenEntry(nil)
cancel()
Expand Down Expand Up @@ -1248,7 +1252,14 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
Path: resp.Auth.CreationPath,
NamespaceID: ns.ID,
}
if err := c.expiration.RegisterAuth(ctx, registeredTokenEntry, resp.Auth, c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx)); err != nil {

// Check for request role
var role string
if reqRole := ctx.Value(logical.CtxKeyRequestRole{}); reqRole != nil {
role = reqRole.(string)
}

if err := c.expiration.RegisterAuth(ctx, registeredTokenEntry, resp.Auth, role); err != nil {
// Best-effort clean up on error, so we log the cleanup error as
// a warning but still return as internal error.
if err := c.tokenStore.revokeOrphan(ctx, resp.Auth.ClientToken); err != nil {
Expand Down Expand Up @@ -1477,12 +1488,18 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
return
}

// Check for request role
var role string
if reqRole := ctx.Value(logical.CtxKeyRequestRole{}); reqRole != nil {
role = reqRole.(string)
}

// The request successfully authenticated itself. Run the quota checks
// before creating lease.
quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, &quotas.Request{
Path: req.Path,
MountPath: strings.TrimPrefix(req.MountPoint, ns.Path),
Role: c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx),
Role: role,
NamespacePath: ns.Path,
})

Expand Down Expand Up @@ -1674,7 +1691,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
// Attach the display name, might be used by audit backends
req.DisplayName = auth.DisplayName

leaseGen, respTokenCreate, errCreateToken := c.LoginCreateToken(ctx, ns, req.Path, source, resp, req.Data)
leaseGen, respTokenCreate, errCreateToken := c.LoginCreateToken(ctx, ns, req.Path, source, role, resp)
leaseGenerated = leaseGen
if errCreateToken != nil {
return respTokenCreate, nil, errCreateToken
Expand Down Expand Up @@ -1726,9 +1743,8 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
// LoginCreateToken creates a token as a result of a login request.
// If MFA is enforced, mfa/validate endpoint calls this functions
// after successful MFA validation to generate the token.
func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, reqPath, mountPoint string, resp *logical.Response, loginRequestData map[string]interface{}) (bool, *logical.Response, error) {
func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, reqPath, mountPoint, role string, resp *logical.Response) (bool, *logical.Response, error) {
auth := resp.Auth

source := strings.TrimPrefix(mountPoint, credentialRoutePrefix)
source = strings.ReplaceAll(source, "/", "-")

Expand Down Expand Up @@ -1788,7 +1804,7 @@ func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, re
}

leaseGenerated := false
err = registerFunc(ctx, tokenTTL, reqPath, auth, c.DetermineRoleFromLoginRequest(mountPoint, loginRequestData, ctx))
err = registerFunc(ctx, tokenTTL, reqPath, auth, role)
switch {
case err == nil:
if auth.TokenType != logical.TokenTypeBatch {
Expand Down

0 comments on commit cccfdb0

Please sign in to comment.